Skip to content

Commit

Permalink
Introduce staff-only inline completion provider (#21739)
Browse files Browse the repository at this point in the history
Release Notes:

- N/A

---------

Co-authored-by: Thorsten Ball <[email protected]>
Co-authored-by: Bennet <[email protected]>
Co-authored-by: Thorsten <[email protected]>
  • Loading branch information
4 people authored Dec 9, 2024
1 parent 39e8944 commit 77b8296
Show file tree
Hide file tree
Showing 39 changed files with 2,881 additions and 347 deletions.
42 changes: 40 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ members = [
"crates/worktree",
"crates/zed",
"crates/zed_actions",
"crates/zeta",

#
# Extensions
Expand Down Expand Up @@ -325,6 +326,7 @@ workspace = { path = "crates/workspace" }
worktree = { path = "crates/worktree" }
zed = { path = "crates/zed" }
zed_actions = { path = "crates/zed_actions" }
zeta = { path = "crates/zeta" }

#
# External crates
Expand Down
21 changes: 20 additions & 1 deletion crates/client/src/telemetry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ use std::time::Instant;
use std::{env, mem, path::PathBuf, sync::Arc, time::Duration};
use telemetry_events::{
ActionEvent, AppEvent, AssistantEvent, CallEvent, EditEvent, EditorEvent, Event,
EventRequestBody, EventWrapper, ExtensionEvent, InlineCompletionEvent, ReplEvent, SettingEvent,
EventRequestBody, EventWrapper, ExtensionEvent, InlineCompletionEvent, InlineCompletionRating,
InlineCompletionRatingEvent, ReplEvent, SettingEvent,
};
use util::{ResultExt, TryFutureExt};
use worktree::{UpdatedEntriesSet, WorktreeId};
Expand Down Expand Up @@ -355,6 +356,24 @@ impl Telemetry {
self.report_event(event)
}

pub fn report_inline_completion_rating_event(
self: &Arc<Self>,
rating: InlineCompletionRating,
input_events: Arc<str>,
input_excerpt: Arc<str>,
output_excerpt: Arc<str>,
feedback: String,
) {
let event = Event::InlineCompletionRating(InlineCompletionRatingEvent {
rating,
input_events,
input_excerpt,
output_excerpt,
feedback,
});
self.report_event(event);
}

pub fn report_assistant_event(self: &Arc<Self>, event: AssistantEvent) {
self.report_event(Event::Assistant(event));
}
Expand Down
15 changes: 15 additions & 0 deletions crates/collab/k8s/collab.template.yml
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,21 @@ spec:
secretKeyRef:
name: google-ai
key: api_key
- name: PREDICTION_API_URL
valueFrom:
secretKeyRef:
name: prediction
key: api_url
- name: PREDICTION_API_KEY
valueFrom:
secretKeyRef:
name: prediction
key: api_key
- name: PREDICTION_MODEL
valueFrom:
secretKeyRef:
name: prediction
key: model
- name: BLOB_STORE_ACCESS_KEY
valueFrom:
secretKeyRef:
Expand Down
6 changes: 5 additions & 1 deletion crates/collab/src/api/events.rs
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ pub async fn post_events(
checksum_matched,
))
}
Event::Cpu(_) | Event::Memory(_) => continue,
Event::Cpu(_) | Event::Memory(_) | Event::InlineCompletionRating(_) => continue,
Event::App(event) => to_upload.app_events.push(AppEventRow::from_event(
event.clone(),
wrapper,
Expand Down Expand Up @@ -1406,6 +1406,10 @@ fn for_snowflake(
),
serde_json::to_value(e).unwrap(),
),
Event::InlineCompletionRating(e) => (
"Inline Completion Feedback".to_string(),
serde_json::to_value(e).unwrap(),
),
Event::Call(e) => {
let event_type = match e.operation.trim() {
"unshare project" => "Project Unshared".to_string(),
Expand Down
6 changes: 6 additions & 0 deletions crates/collab/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,9 @@ pub struct Config {
pub anthropic_api_key: Option<Arc<str>>,
pub anthropic_staff_api_key: Option<Arc<str>>,
pub llm_closed_beta_model_name: Option<Arc<str>>,
pub prediction_api_url: Option<Arc<str>>,
pub prediction_api_key: Option<Arc<str>>,
pub prediction_model: Option<Arc<str>>,
pub zed_client_checksum_seed: Option<String>,
pub slack_panics_webhook: Option<String>,
pub auto_join_channel_id: Option<ChannelId>,
Expand Down Expand Up @@ -230,6 +233,9 @@ impl Config {
anthropic_api_key: None,
anthropic_staff_api_key: None,
llm_closed_beta_model_name: None,
prediction_api_url: None,
prediction_api_key: None,
prediction_model: None,
clickhouse_url: None,
clickhouse_user: None,
clickhouse_password: None,
Expand Down
59 changes: 58 additions & 1 deletion crates/collab/src/llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ use reqwest_client::ReqwestClient;
use rpc::{
proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
};
use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME};
use rpc::{
ListModelsResponse, PredictEditsParams, PredictEditsResponse,
MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
};
use serde_json::json;
use std::{
pin::Pin,
Expand Down Expand Up @@ -126,6 +129,7 @@ pub fn routes() -> Router<(), Body> {
Router::new()
.route("/models", get(list_models))
.route("/completion", post(perform_completion))
.route("/predict_edits", post(predict_edits))
.layer(middleware::from_fn(validate_api_token))
}

Expand Down Expand Up @@ -439,6 +443,59 @@ fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
}
}

async fn predict_edits(
Extension(state): Extension<Arc<LlmState>>,
Extension(claims): Extension<LlmTokenClaims>,
_country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
Json(params): Json<PredictEditsParams>,
) -> Result<impl IntoResponse> {
if !claims.is_staff {
return Err(anyhow!("not found"))?;
}

let api_url = state
.config
.prediction_api_url
.as_ref()
.context("no PREDICTION_API_URL configured on the server")?;
let api_key = state
.config
.prediction_api_key
.as_ref()
.context("no PREDICTION_API_KEY configured on the server")?;
let model = state
.config
.prediction_model
.as_ref()
.context("no PREDICTION_MODEL configured on the server")?;
let prompt = include_str!("./llm/prediction_prompt.md")
.replace("<events>", &params.input_events)
.replace("<excerpt>", &params.input_excerpt);
let mut response = open_ai::complete_text(
&state.http_client,
api_url,
api_key,
open_ai::CompletionRequest {
model: model.to_string(),
prompt: prompt.clone(),
max_tokens: 1024,
temperature: 0.,
prediction: Some(open_ai::Prediction::Content {
content: params.input_excerpt,
}),
rewrite_speculation: Some(true),
},
)
.await?;
let choice = response
.choices
.pop()
.context("no output from completion response")?;
Ok(Json(PredictEditsResponse {
output_excerpt: choice.text,
}))
}

/// The maximum monthly spending an individual user can reach on the free tier
/// before they have to pay.
pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(10);
Expand Down
12 changes: 12 additions & 0 deletions crates/collab/src/llm/prediction_prompt.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.

### Events:
<events>

### Input:
<excerpt>

### Response:
3 changes: 3 additions & 0 deletions crates/collab/src/tests/test_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,9 @@ impl TestServer {
anthropic_api_key: None,
anthropic_staff_api_key: None,
llm_closed_beta_model_name: None,
prediction_api_url: None,
prediction_api_key: None,
prediction_model: None,
clickhouse_url: None,
clickhouse_user: None,
clickhouse_password: None,
Expand Down
7 changes: 5 additions & 2 deletions crates/copilot/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,18 +59,21 @@ workspace.workspace = true
async-std = { version = "1.12.0", features = ["unstable"] }

[dev-dependencies]
clock.workspace = true
indoc.workspace = true
serde_json.workspace = true
clock = { workspace = true, features = ["test-support"] }
client = { workspace = true, features = ["test-support"] }
collections = { workspace = true, features = ["test-support"] }
editor = { workspace = true, features = ["test-support"] }
fs = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] }
http_client = { workspace = true, features = ["test-support"] }
language = { workspace = true, features = ["test-support"] }
lsp = { workspace = true, features = ["test-support"] }
node_runtime = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
rpc = { workspace = true, features = ["test-support"] }
settings = { workspace = true, features = ["test-support"] }
theme = { workspace = true, features = ["test-support"] }
util = { workspace = true, features = ["test-support"] }
http_client = { workspace = true, features = ["test-support"] }
workspace = { workspace = true, features = ["test-support"] }
Loading

0 comments on commit 77b8296

Please sign in to comment.