diff --git a/src/server.go b/src/server.go index 80e7e31..88e7c28 100644 --- a/src/server.go +++ b/src/server.go @@ -195,7 +195,7 @@ func (a *App) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // lock the key while a PUT or DELETE is in progress - if r.Method == "POST" || r.Method == "PUT" || r.Method == "DELETE" || r.Method == "UNLINK" || r.Method == "REBALANCE" { + if r.Method == "POST" || r.Method == "PUT" || r.Method == "REPLACE" || r.Method == "DELETE" || r.Method == "UNLINK" || r.Method == "REBALANCE" { if !a.LockKey(lkey) { // Conflict, retry later w.WriteHeader(409) @@ -324,19 +324,21 @@ func (a *App) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Write([]byte("")) return } - case "PUT": + case "PUT", "REPLACE": // no empty values if r.ContentLength == 0 { w.WriteHeader(411) return } - // check if we already have the key, and it's not deleted - rec := a.GetRecord(key) - if rec.deleted == NO { - // Forbidden to overwrite with PUT - w.WriteHeader(403) - return + if r.Method == "PUT" { + // check if we already have the key, and it's not deleted + rec := a.GetRecord(key) + if rec.deleted == NO { + // Forbidden to overwrite with PUT + w.WriteHeader(403) + return + } } if pn := r.URL.Query().Get("partNumber"); pn != "" { diff --git a/tools/test.py b/tools/test.py index f3ebf66..09e4e89 100755 --- a/tools/test.py +++ b/tools/test.py @@ -124,6 +124,42 @@ def test_head_request(self): # redirect, content length should be size of data self.assertEqual(int(r.headers['content-length']), len(data)) + def test_put_replace(self): + key = self.get_fresh_key() + + r = requests.put(key, data="abc") + self.assertEqual(r.status_code, 201) + + r = requests.request("REPLACE", key, data="123") + self.assertEqual(r.status_code, 201) + + r = requests.get(key) + self.assertEqual(r.status_code, 200) + self.assertEqual(r.text, "123") + + r = requests.delete(key) + self.assertEqual(r.status_code, 204) + + def test_replace_replace(self): + key = self.get_fresh_key() + + r = requests.request("REPLACE", key, data="123") + self.assertEqual(r.status_code, 201) + + r = requests.get(key) + self.assertEqual(r.status_code, 200) + self.assertEqual(r.text, "123") + + r = requests.request("REPLACE", key, data="456") + self.assertEqual(r.status_code, 201) + + r = requests.get(key) + self.assertEqual(r.status_code, 200) + self.assertEqual(r.text, "456") + + r = requests.delete(key) + self.assertEqual(r.status_code, 204) + def test_large_key(self): key = self.get_fresh_key() diff --git a/tools/thrasher.go b/tools/thrasher.go index 4487cdb..6e696ec 100644 --- a/tools/thrasher.go +++ b/tools/thrasher.go @@ -29,8 +29,8 @@ func remote_delete(remote string) error { return nil } -func remote_put(remote string, length int64, body io.Reader) error { - req, err := http.NewRequest("PUT", remote, body) +func remote_set(method string, remote string, length int64, body io.Reader) error { + req, err := http.NewRequest(method, remote, body) if err != nil { return err } @@ -41,7 +41,7 @@ func remote_put(remote string, length int64, body io.Reader) error { } defer resp.Body.Close() if resp.StatusCode != 201 && resp.StatusCode != 204 { - return fmt.Errorf("remote_put: wrong status code %d", resp.StatusCode) + return fmt.Errorf("remote_set: wrong status code %d", resp.StatusCode) } return nil } @@ -76,9 +76,14 @@ func main() { go func() { for { key := <-reqs - value := fmt.Sprintf("value-%d", rand.Int()) - if err := remote_put("http://localhost:3000/"+key, int64(len(value)), strings.NewReader(value)); err != nil { - fmt.Println("PUT FAILED", err) + val := rand.Int() + value := fmt.Sprintf("value-%d", val) + method := "PUT" + if val % 4 == 0 { + method = "REPLACE" + } + if err := remote_set(method, "http://localhost:3000/"+key, int64(len(value)), strings.NewReader(value)); err != nil { + fmt.Println(method + " FAILED", err) resp <- false continue }