-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathservice_ai.go
126 lines (115 loc) · 3.23 KB
/
service_ai.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
package main
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"net/http"
"strings"
"github.com/gotk3/gotk3/glib"
"github.com/gotk3/gotk3/gtk"
)
// Prompt is the JSON request body sent to the llama completion endpoint.
// Field names mix snake_case and camelCase because they must match the
// server's wire format exactly — do not "fix" the tags without checking
// the API. (NOTE(review): semantics below are inferred from common LLM
// sampling parameters; confirm against the server's documentation.)
type Prompt struct {
	MaxTokens int `json:"maxTokens"` // upper bound on generated tokens
	Input string `json:"input"` // the full prompt text (see GeneratePrompt)
	Model string `json:"model"` // model identifier, as listed by GetModels
	IgnoreEos bool `json:"ignore_eos"` // keep generating past end-of-sequence
	TopK int `json:"top_k"` // top-k sampling cutoff
	TopP float64 `json:"top_p"` // nucleus sampling cutoff
	Temperature float64 `json:"temperature"` // sampling temperature
	Mirostat int `json:"mirostat"` // mirostat mode — presumably 0/1/2; verify with server
	Entropy int `json:"entropy"` // mirostat target entropy (tau) — TODO confirm
	LearningRate float64 `json:"learningRate"` // mirostat learning rate (eta) — TODO confirm
	TailFreeSamplingRate int `json:"tailFreeSamplingRate"` // tail-free sampling z; 1 disables
	TypicalP int `json:"typical_p"` // typical sampling p; 1 disables
	PenalizeNewLines bool `json:"penalizeNewLines"` // apply repetition penalty to newlines
	PenalizeSpaces bool `json:"penalizeSpaces"` // apply repetition penalty to spaces
	RepetitionPenalty float64 `json:"repetition_penalty"` // >1 discourages repetition
	IncludeIngest bool `json:"includeIngest"` // echo ingested prompt in the response stream
	IncludeStatistics bool `json:"includeStatistics"` // append generation statistics to the stream
	OneShot bool `json:"oneShot"` // single request/response, no session — TODO confirm
}
// RunInference POSTs prompt p to the completion endpoint and streams the
// response into the given label, pushing each partial result onto the GTK
// main loop via glib.IdleAdd. Blocks until the stream ends, so call it
// from a goroutine, not from the UI thread.
func RunInference(p Prompt, body *gtk.Label) {
	payload, err := json.Marshal(p)
	if err != nil {
		fmt.Println("Failed to encode prompt", err)
		return
	}
	resp, err := http.Post("http://llama.her.st/completion", "application/json", bytes.NewBuffer(payload))
	if err != nil {
		fmt.Println("Something went wrong with the completion", err)
		return
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		fmt.Println("Completion request failed with status", resp.Status)
		return
	}
	reader := bufio.NewReader(resp.Body)
	builder := strings.Builder{}
	// Small read buffer keeps the label updating as tokens trickle in.
	chunk := make([]byte, 2)
	for {
		n, err := reader.Read(chunk)
		if n > 0 {
			// Only append the bytes actually read; the old code wrote a
			// stale/zero byte into the text when a read hit EOF mid-pair.
			builder.WriteString(string(chunk[:n]))
			// Snapshot the text before handing it to the main loop: the
			// closure runs on the GTK thread while this goroutine keeps
			// mutating builder, so capturing builder itself would race.
			text := builder.String()
			glib.IdleAdd(func() {
				body.SetText(text)
			})
		}
		if err != nil {
			// io.EOF is the normal end of stream; anything else is still
			// best-effort terminal for this one-shot UI update.
			break
		}
	}
	fmt.Println("Finished inference")
}
// GeneratePrompt wraps input in the template for the given mode
// ("completion", "instruction" or "chat"; any other mode yields an empty
// input) and returns a Prompt populated with the client's default
// sampling parameters.
func GeneratePrompt(mode string, input string, maxToken int, model string, inclIngest bool, inclStats bool) Prompt {
	var sb strings.Builder
	switch mode {
	case "completion":
		// Raw continuation: the input is the prompt verbatim.
		sb.WriteString(input)
	case "instruction":
		// Alpaca-style instruction/response template.
		sb.WriteString("### Instruction: ")
		sb.WriteString("\n")
		sb.WriteString(input)
		sb.WriteString("\n")
		sb.WriteString("### Response:")
		sb.WriteString("\n")
	case "chat":
		// Simple two-speaker chat transcript.
		sb.WriteString("User: ")
		sb.WriteString(input)
		sb.WriteString("\n")
		sb.WriteString("AI:")
	}
	return Prompt{
		MaxTokens: maxToken,
		Input: sb.String(),
		Model: model,
		IgnoreEos: false,
		TopK: 20,
		TopP: 0.9,
		Temperature: 0.2,
		Mirostat: 2,
		Entropy: 3,
		LearningRate: 0.003,
		TailFreeSamplingRate: 1,
		TypicalP: 1,
		PenalizeNewLines: false,
		PenalizeSpaces: false,
		RepetitionPenalty: 1.15,
		IncludeIngest: inclIngest,
		IncludeStatistics: inclStats,
		OneShot: true,
	}
}
// GetModels fetches the newline-separated model list from the server and
// returns it as a slice. On any error it logs and returns whatever was
// collected so far (possibly an empty, non-nil slice).
func GetModels() []string {
	models := []string{}
	resp, err := http.Get("http://llama.her.st/models")
	if err != nil {
		fmt.Println("Failed to get list of models.", err)
		// The old code fell through here and dereferenced the nil resp,
		// panicking on any network failure; bail out instead.
		return models
	}
	defer resp.Body.Close()
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		models = append(models, scanner.Text())
	}
	if err := scanner.Err(); err != nil {
		fmt.Println("Failed to read model list.", err)
	}
	return models
}