Initial commit

This commit is contained in:
Donny
2019-04-22 20:46:32 +08:00
commit 49ab8aadd1
25441 changed files with 4055000 additions and 0 deletions

View File

@@ -0,0 +1,49 @@
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
"go_test",
)
go_library(
name = "go_default_library",
srcs = ["server.go"],
importpath = "go-common/library/net/http",
tags = ["automanaged"],
visibility = ["//visibility:public"],
deps = [
"//library/conf/env:go_default_library",
"//library/log:go_default_library",
"//library/net/http/blademaster:go_default_library",
"//library/time:go_default_library",
"//vendor/github.com/pkg/errors:go_default_library",
],
)
go_test(
name = "go_default_test",
srcs = ["server_test.go"],
embed = [":go_default_library"],
rundir = ".",
tags = ["automanaged"],
deps = [
"//library/net/http/blademaster:go_default_library",
"//library/time:go_default_library",
],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [
":package-srcs",
"//library/net/http/blademaster:all-srcs",
],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)

View File

@@ -0,0 +1,26 @@
### http
#### Version 1.3.0
> 1.去掉了handle.go
> 2.server2.go改成serve.goServer2方法改为Serve
#### Version 1.2.2
> 1.支持上报熔断错误到prometheus平台
#### Version 1.2.1
> 1.修复使用了elk默认字段message
#### Version 1.2.0
> 1.拆封Do,JSON,PB,Raw方法
#### Version 1.1.0
> 1.添加http VeriryUser方法
#### Version 1.0.1
> 1. 修复了读取配置时潜在的数据竞争
#### Version 1.0.0
> 1.修复配置了location时breaker不生效的问题
> 2.合并RestfulDo到Do中
> 3.breaker配置只使用最外层的url和host仅配置timeout

View File

@@ -0,0 +1,13 @@
# Owner
maojian
# Author
maojian
zhapuyu
haoguanwei
peiyifei
# Reviewer
maojian
zhapuyu
haoguanwei

12
library/net/http/OWNERS Normal file
View File

@@ -0,0 +1,12 @@
# See the OWNERS docs at https://go.k8s.io/owners
approvers:
- haoguanwei
- maojian
- peiyifei
- zhapuyu
reviewers:
- haoguanwei
- maojian
- peiyifei
- zhapuyu

View File

@@ -0,0 +1,10 @@
# go-common/net/http
#### 项目简介
> 1. 提供http模块
#### 编译环境
> 1. 请只用golang v1.8.x以上版本编译执行。
##### 测试
> 1. 执行当前目录下所有测试文件,测试所有功能

View File

@@ -0,0 +1,116 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_test",
"go_library",
)
go_test(
name = "go_default_test",
srcs = [
"client_test.go",
"server_test.go",
"trace_test.go",
],
embed = [":go_default_library"],
rundir = ".",
tags = ["automanaged"],
deps = [
"//library/ecode:go_default_library",
"//library/log:go_default_library",
"//library/net/http/blademaster/binding:go_default_library",
"//library/net/http/blademaster/render:go_default_library",
"//library/net/http/blademaster/tests:go_default_library",
"//library/net/metadata:go_default_library",
"//library/net/netutil/breaker:go_default_library",
"//library/net/trace:go_default_library",
"//library/time:go_default_library",
"//vendor/github.com/pkg/errors:go_default_library",
"//vendor/github.com/stretchr/testify/assert:go_default_library",
"@com_github_gogo_protobuf//proto:go_default_library",
],
)
go_library(
name = "go_default_library",
srcs = [
"client.go",
"context.go",
"cors.go",
"csrf.go",
"device.go",
"logger.go",
"metadata.go",
"perf.go",
"prometheus.go",
"recovery.go",
"routergroup.go",
"server.go",
"trace.go",
"utils.go",
],
importpath = "go-common/library/net/http/blademaster",
tags = ["automanaged"],
visibility = ["//visibility:public"],
deps = [
"//library/conf/dsn:go_default_library",
"//library/conf/env:go_default_library",
"//library/ecode:go_default_library",
"//library/log:go_default_library",
"//library/net/http/blademaster/binding:go_default_library",
"//library/net/http/blademaster/render:go_default_library",
"//library/net/ip:go_default_library",
"//library/net/metadata:go_default_library",
"//library/net/netutil/breaker:go_default_library",
"//library/net/trace:go_default_library",
"//library/stat:go_default_library",
"//library/time:go_default_library",
"//vendor/github.com/pkg/errors:go_default_library",
"//vendor/github.com/prometheus/client_golang/prometheus/promhttp:go_default_library",
"@com_github_gogo_protobuf//proto:go_default_library",
"@com_github_gogo_protobuf//types:go_default_library",
],
)
go_test(
name = "go_default_xtest",
srcs = ["example_test.go"],
tags = ["automanaged"],
deps = [
"//library/net/http/blademaster:go_default_library",
"//library/net/http/blademaster/binding:go_default_library",
"//library/net/http/blademaster/middleware/auth:go_default_library",
"//library/net/http/blademaster/middleware/verify:go_default_library",
"//library/net/http/blademaster/tests:go_default_library",
],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [
":package-srcs",
"//library/net/http/blademaster/binding:all-srcs",
"//library/net/http/blademaster/middleware/antispam:all-srcs",
"//library/net/http/blademaster/middleware/auth:all-srcs",
"//library/net/http/blademaster/middleware/cache:all-srcs",
"//library/net/http/blademaster/middleware/limit/aqm:all-srcs",
"//library/net/http/blademaster/middleware/permit:all-srcs",
"//library/net/http/blademaster/middleware/proxy:all-srcs",
"//library/net/http/blademaster/middleware/rate:all-srcs",
"//library/net/http/blademaster/middleware/supervisor:all-srcs",
"//library/net/http/blademaster/middleware/tag:all-srcs",
"//library/net/http/blademaster/middleware/verify:all-srcs",
"//library/net/http/blademaster/render:all-srcs",
"//library/net/http/blademaster/tests:all-srcs",
],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)

View File

@@ -0,0 +1,49 @@
### net/http/blademaster
##### Version 1.1.4
1. 临时移除 httptrace 避免 datarace
##### Version 1.1.3
1. bind 错误设置到context error
##### Version 1.1.2
1. 将 ecode 作为 header 写入
##### Version 1.1.1
1. device 信息加入metadata
##### Version 1.1.0
1. 对压测流量打标写入md
##### Version 1.0.6
1. 业务错误日志记为 WARN
##### Version 1.0.5
1. 增加 device 中间件
##### Version 1.0.4
1. 增加 metadata 接口,可以获取 Path 和 Method 信息
##### Version 1.0.3
1. 当请求被 CORS 或者 CSRF 模块拒绝后,输出一个 level 为 5 的 Error 日志
##### Version 1.0.2
1. 调整 context.go 里的输出方法参数顺序改为数据在前error 在后
2. Context 里增加 JSONMap 方法,用于适配早期数据结构
3. Recovery 里打印 panic 信息到 stderr
##### Version 1.0.1
1. logger 里增加上报用于监控的 caller
##### Version 1.0.0
1. 完成基本功能与测试

View File

@@ -0,0 +1,8 @@
# Author
maojian
lintnaghui
caoguoliang
zhoujiahui
# Reviewer
maojian

View File

@@ -0,0 +1,12 @@
# See the OWNERS docs at https://go.k8s.io/owners
approvers:
- caoguoliang
- lintnaghui
- maojian
- zhoujiahui
reviewers:
- caoguoliang
- lintnaghui
- maojian
- zhoujiahui

View File

@@ -0,0 +1,19 @@
#### net/http/blademaster
> Blazing fast http framework for humans
##### 项目简介
来自 bilibili 主站技术部的 http 框架,融合主站技术部的核心科技,带来如飞一般的体验。
##### 项目特点
- 模块化设计,核心足够轻量
##### 编译环境
- **请只用 Golang v1.8.x 以上版本编译执行**
##### 依赖包
- No other dependency

View File

@@ -0,0 +1,56 @@
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
"go_test",
)
go_library(
name = "go_default_library",
srcs = [
"binding.go",
"default_validator.go",
"form.go",
"form_mapping.go",
"json.go",
"query.go",
"tags.go",
"xml.go",
],
importpath = "go-common/library/net/http/blademaster/binding",
tags = ["automanaged"],
visibility = ["//visibility:public"],
deps = [
"//vendor/github.com/pkg/errors:go_default_library",
"//vendor/gopkg.in/go-playground/validator.v9:go_default_library",
],
)
go_test(
name = "go_default_test",
srcs = [
"binding_test.go",
"example_test.go",
"validate_test.go",
],
embed = [":go_default_library"],
rundir = ".",
tags = ["automanaged"],
deps = ["//vendor/github.com/stretchr/testify/assert:go_default_library"],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [
":package-srcs",
"//library/net/http/blademaster/binding/example:all-srcs",
],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)

View File

@@ -0,0 +1,85 @@
package binding
import (
"net/http"
"strings"
"gopkg.in/go-playground/validator.v9"
)
// MIME
const (
MIMEJSON = "application/json"
MIMEHTML = "text/html"
MIMEXML = "application/xml"
MIMEXML2 = "text/xml"
MIMEPlain = "text/plain"
MIMEPOSTForm = "application/x-www-form-urlencoded"
MIMEMultipartPOSTForm = "multipart/form-data"
)
// Binding http binding request interface.
type Binding interface {
Name() string
Bind(*http.Request, interface{}) error
}
// StructValidator http validator interface.
type StructValidator interface {
// ValidateStruct can receive any kind of type and it should never panic, even if the configuration is not right.
// If the received type is not a struct, any validation should be skipped and nil must be returned.
// If the received type is a struct or pointer to a struct, the validation should be performed.
// If the struct is not valid or the validation itself fails, a descriptive error should be returned.
// Otherwise nil must be returned.
ValidateStruct(interface{}) error
// RegisterValidation adds a validation Func to a Validate's map of validators denoted by the key
// NOTE: if the key already exists, the previous validation function will be replaced.
// NOTE: this method is not thread-safe it is intended that these all be registered prior to any validation
RegisterValidation(string, validator.Func) error
}
// Validator default validator.
var Validator StructValidator = &defaultValidator{}
// Binding
var (
JSON = jsonBinding{}
XML = xmlBinding{}
Form = formBinding{}
Query = queryBinding{}
FormPost = formPostBinding{}
FormMultipart = formMultipartBinding{}
)
// Default get by binding type by method and contexttype.
func Default(method, contentType string) Binding {
if method == "GET" {
return Form
}
contentType = stripContentTypeParam(contentType)
switch contentType {
case MIMEJSON:
return JSON
case MIMEXML, MIMEXML2:
return XML
default: //case MIMEPOSTForm, MIMEMultipartPOSTForm:
return Form
}
}
func validate(obj interface{}) error {
if Validator == nil {
return nil
}
return Validator.ValidateStruct(obj)
}
func stripContentTypeParam(contentType string) string {
i := strings.Index(contentType, ";")
if i != -1 {
contentType = contentType[:i]
}
return contentType
}

View File

@@ -0,0 +1,342 @@
package binding
import (
"bytes"
"mime/multipart"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
)
type FooStruct struct {
Foo string `msgpack:"foo" json:"foo" form:"foo" xml:"foo" validate:"required"`
}
type FooBarStruct struct {
FooStruct
Bar string `msgpack:"bar" json:"bar" form:"bar" xml:"bar" validate:"required"`
Slice []string `form:"slice" validate:"max=10"`
}
type ComplexDefaultStruct struct {
Int int `form:"int" default:"999"`
String string `form:"string" default:"default-string"`
Bool bool `form:"bool" default:"false"`
Int64Slice []int64 `form:"int64_slice,split" default:"1,2,3,4"`
Int8Slice []int8 `form:"int8_slice,split" default:"1,2,3,4"`
}
type Int8SliceStruct struct {
State []int8 `form:"state,split"`
}
type Int64SliceStruct struct {
State []int64 `form:"state,split"`
}
type StringSliceStruct struct {
State []string `form:"state,split"`
}
func TestBindingDefault(t *testing.T) {
assert.Equal(t, Default("GET", ""), Form)
assert.Equal(t, Default("GET", MIMEJSON), Form)
assert.Equal(t, Default("GET", MIMEJSON+"; charset=utf-8"), Form)
assert.Equal(t, Default("POST", MIMEJSON), JSON)
assert.Equal(t, Default("PUT", MIMEJSON), JSON)
assert.Equal(t, Default("POST", MIMEJSON+"; charset=utf-8"), JSON)
assert.Equal(t, Default("PUT", MIMEJSON+"; charset=utf-8"), JSON)
assert.Equal(t, Default("POST", MIMEXML), XML)
assert.Equal(t, Default("PUT", MIMEXML2), XML)
assert.Equal(t, Default("POST", MIMEPOSTForm), Form)
assert.Equal(t, Default("PUT", MIMEPOSTForm), Form)
assert.Equal(t, Default("POST", MIMEPOSTForm+"; charset=utf-8"), Form)
assert.Equal(t, Default("PUT", MIMEPOSTForm+"; charset=utf-8"), Form)
assert.Equal(t, Default("POST", MIMEMultipartPOSTForm), Form)
assert.Equal(t, Default("PUT", MIMEMultipartPOSTForm), Form)
}
func TestStripContentType(t *testing.T) {
c1 := "application/vnd.mozilla.xul+xml"
c2 := "application/vnd.mozilla.xul+xml; charset=utf-8"
assert.Equal(t, stripContentTypeParam(c1), c1)
assert.Equal(t, stripContentTypeParam(c2), "application/vnd.mozilla.xul+xml")
}
func TestBindInt8Form(t *testing.T) {
params := "state=1,2,3"
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q := new(Int8SliceStruct)
Form.Bind(req, q)
assert.EqualValues(t, []int8{1, 2, 3}, q.State)
params = "state=1,2,3,256"
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q = new(Int8SliceStruct)
assert.Error(t, Form.Bind(req, q))
params = "state="
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q = new(Int8SliceStruct)
assert.NoError(t, Form.Bind(req, q))
assert.Len(t, q.State, 0)
params = "state=1,,2"
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q = new(Int8SliceStruct)
assert.NoError(t, Form.Bind(req, q))
assert.EqualValues(t, []int8{1, 2}, q.State)
}
func TestBindInt64Form(t *testing.T) {
params := "state=1,2,3"
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q := new(Int64SliceStruct)
Form.Bind(req, q)
assert.EqualValues(t, []int64{1, 2, 3}, q.State)
params = "state="
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q = new(Int64SliceStruct)
assert.NoError(t, Form.Bind(req, q))
assert.Len(t, q.State, 0)
}
func TestBindStringForm(t *testing.T) {
params := "state=1,2,3"
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q := new(StringSliceStruct)
Form.Bind(req, q)
assert.EqualValues(t, []string{"1", "2", "3"}, q.State)
params = "state="
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q = new(StringSliceStruct)
assert.NoError(t, Form.Bind(req, q))
assert.Len(t, q.State, 0)
params = "state=p,,p"
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q = new(StringSliceStruct)
Form.Bind(req, q)
assert.EqualValues(t, []string{"p", "p"}, q.State)
}
func TestBindingJSON(t *testing.T) {
testBodyBinding(t,
JSON, "json",
"/", "/",
`{"foo": "bar"}`, `{"bar": "foo"}`)
}
func TestBindingForm(t *testing.T) {
testFormBinding(t, "POST",
"/", "/",
"foo=bar&bar=foo&slice=a&slice=b", "bar2=foo")
}
func TestBindingForm2(t *testing.T) {
testFormBinding(t, "GET",
"/?foo=bar&bar=foo", "/?bar2=foo",
"", "")
}
func TestBindingQuery(t *testing.T) {
testQueryBinding(t, "POST",
"/?foo=bar&bar=foo", "/",
"foo=unused", "bar2=foo")
}
func TestBindingQuery2(t *testing.T) {
testQueryBinding(t, "GET",
"/?foo=bar&bar=foo", "/?bar2=foo",
"foo=unused", "")
}
func TestBindingXML(t *testing.T) {
testBodyBinding(t,
XML, "xml",
"/", "/",
"<map><foo>bar</foo></map>", "<map><bar>foo</bar></map>")
}
func createFormPostRequest() *http.Request {
req, _ := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", bytes.NewBufferString("foo=bar&bar=foo"))
req.Header.Set("Content-Type", MIMEPOSTForm)
return req
}
func createFormMultipartRequest() *http.Request {
boundary := "--testboundary"
body := new(bytes.Buffer)
mw := multipart.NewWriter(body)
defer mw.Close()
mw.SetBoundary(boundary)
mw.WriteField("foo", "bar")
mw.WriteField("bar", "foo")
req, _ := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", body)
req.Header.Set("Content-Type", MIMEMultipartPOSTForm+"; boundary="+boundary)
return req
}
func TestBindingFormPost(t *testing.T) {
req := createFormPostRequest()
var obj FooBarStruct
FormPost.Bind(req, &obj)
assert.Equal(t, obj.Foo, "bar")
assert.Equal(t, obj.Bar, "foo")
}
func TestBindingFormMultipart(t *testing.T) {
req := createFormMultipartRequest()
var obj FooBarStruct
FormMultipart.Bind(req, &obj)
assert.Equal(t, obj.Foo, "bar")
assert.Equal(t, obj.Bar, "foo")
}
func TestValidationFails(t *testing.T) {
var obj FooStruct
req := requestWithBody("POST", "/", `{"bar": "foo"}`)
err := JSON.Bind(req, &obj)
assert.Error(t, err)
}
func TestValidationDisabled(t *testing.T) {
backup := Validator
Validator = nil
defer func() { Validator = backup }()
var obj FooStruct
req := requestWithBody("POST", "/", `{"bar": "foo"}`)
err := JSON.Bind(req, &obj)
assert.NoError(t, err)
}
func TestExistsSucceeds(t *testing.T) {
type HogeStruct struct {
Hoge *int `json:"hoge" binding:"exists"`
}
var obj HogeStruct
req := requestWithBody("POST", "/", `{"hoge": 0}`)
err := JSON.Bind(req, &obj)
assert.NoError(t, err)
}
func TestFormDefaultValue(t *testing.T) {
params := "int=333&string=hello&bool=true&int64_slice=5,6,7,8&int8_slice=5,6,7,8"
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q := new(ComplexDefaultStruct)
assert.NoError(t, Form.Bind(req, q))
assert.Equal(t, 333, q.Int)
assert.Equal(t, "hello", q.String)
assert.Equal(t, true, q.Bool)
assert.EqualValues(t, []int64{5, 6, 7, 8}, q.Int64Slice)
assert.EqualValues(t, []int8{5, 6, 7, 8}, q.Int8Slice)
params = "string=hello&bool=false"
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q = new(ComplexDefaultStruct)
assert.NoError(t, Form.Bind(req, q))
assert.Equal(t, 999, q.Int)
assert.Equal(t, "hello", q.String)
assert.Equal(t, false, q.Bool)
assert.EqualValues(t, []int64{1, 2, 3, 4}, q.Int64Slice)
assert.EqualValues(t, []int8{1, 2, 3, 4}, q.Int8Slice)
params = "strings=hello"
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q = new(ComplexDefaultStruct)
assert.NoError(t, Form.Bind(req, q))
assert.Equal(t, 999, q.Int)
assert.Equal(t, "default-string", q.String)
assert.Equal(t, false, q.Bool)
assert.EqualValues(t, []int64{1, 2, 3, 4}, q.Int64Slice)
assert.EqualValues(t, []int8{1, 2, 3, 4}, q.Int8Slice)
params = "int=&string=&bool=true&int64_slice=&int8_slice="
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q = new(ComplexDefaultStruct)
assert.NoError(t, Form.Bind(req, q))
assert.Equal(t, 999, q.Int)
assert.Equal(t, "default-string", q.String)
assert.Equal(t, true, q.Bool)
assert.EqualValues(t, []int64{1, 2, 3, 4}, q.Int64Slice)
assert.EqualValues(t, []int8{1, 2, 3, 4}, q.Int8Slice)
}
func testFormBinding(t *testing.T, method, path, badPath, body, badBody string) {
b := Form
assert.Equal(t, b.Name(), "form")
obj := FooBarStruct{}
req := requestWithBody(method, path, body)
if method == "POST" {
req.Header.Add("Content-Type", MIMEPOSTForm)
}
err := b.Bind(req, &obj)
assert.NoError(t, err)
assert.Equal(t, obj.Foo, "bar")
assert.Equal(t, obj.Bar, "foo")
obj = FooBarStruct{}
req = requestWithBody(method, badPath, badBody)
err = JSON.Bind(req, &obj)
assert.Error(t, err)
}
func testQueryBinding(t *testing.T, method, path, badPath, body, badBody string) {
b := Query
assert.Equal(t, b.Name(), "query")
obj := FooBarStruct{}
req := requestWithBody(method, path, body)
if method == "POST" {
req.Header.Add("Content-Type", MIMEPOSTForm)
}
err := b.Bind(req, &obj)
assert.NoError(t, err)
assert.Equal(t, obj.Foo, "bar")
assert.Equal(t, obj.Bar, "foo")
}
func testBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody string) {
assert.Equal(t, b.Name(), name)
obj := FooStruct{}
req := requestWithBody("POST", path, body)
err := b.Bind(req, &obj)
assert.NoError(t, err)
assert.Equal(t, obj.Foo, "bar")
obj = FooStruct{}
req = requestWithBody("POST", badPath, badBody)
err = JSON.Bind(req, &obj)
assert.Error(t, err)
}
func requestWithBody(method, path, body string) (req *http.Request) {
req, _ = http.NewRequest(method, path, bytes.NewBufferString(body))
return
}
func BenchmarkBindingForm(b *testing.B) {
req := requestWithBody("POST", "/", "foo=bar&bar=foo&slice=a&slice=b&slice=c&slice=w")
req.Header.Add("Content-Type", MIMEPOSTForm)
f := Form
for i := 0; i < b.N; i++ {
obj := FooBarStruct{}
f.Bind(req, &obj)
}
}

View File

@@ -0,0 +1,45 @@
package binding
import (
"reflect"
"sync"
"gopkg.in/go-playground/validator.v9"
)
type defaultValidator struct {
once sync.Once
validate *validator.Validate
}
var _ StructValidator = &defaultValidator{}
func (v *defaultValidator) ValidateStruct(obj interface{}) error {
if kindOfData(obj) == reflect.Struct {
v.lazyinit()
if err := v.validate.Struct(obj); err != nil {
return err
}
}
return nil
}
func (v *defaultValidator) RegisterValidation(key string, fn validator.Func) error {
v.lazyinit()
return v.validate.RegisterValidation(key, fn)
}
func (v *defaultValidator) lazyinit() {
v.once.Do(func() {
v.validate = validator.New()
})
}
func kindOfData(data interface{}) reflect.Kind {
value := reflect.ValueOf(data)
valueType := value.Kind()
if valueType == reflect.Ptr {
valueType = value.Elem().Kind()
}
return valueType
}

View File

@@ -0,0 +1,45 @@
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
)
load(
"@io_bazel_rules_go//proto:def.bzl",
"go_proto_library",
)
proto_library(
name = "example_proto",
srcs = ["test.proto"],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)
go_proto_library(
name = "example_go_proto",
compilers = ["@io_bazel_rules_go//proto:go_proto"],
importpath = "go-common/library/net/http/blademaster/binding/example",
proto = ":example_proto",
tags = ["automanaged"],
visibility = ["//visibility:public"],
)
go_library(
name = "go_default_library",
embed = [":example_go_proto"],
importpath = "go-common/net/http/blademaster/binding/example",
visibility = ["//visibility:public"],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)

View File

@@ -0,0 +1,113 @@
// Code generated by protoc-gen-go.
// source: test.proto
// DO NOT EDIT!
/*
Package example is a generated protocol buffer package.
It is generated from these files:
test.proto
It has these top-level messages:
Test
*/
package example
import proto "github.com/golang/protobuf/proto"
import math "math"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = math.Inf
type FOO int32
const (
FOO_X FOO = 17
)
var FOO_name = map[int32]string{
17: "X",
}
var FOO_value = map[string]int32{
"X": 17,
}
func (x FOO) Enum() *FOO {
p := new(FOO)
*p = x
return p
}
func (x FOO) String() string {
return proto.EnumName(FOO_name, int32(x))
}
func (x *FOO) UnmarshalJSON(data []byte) error {
value, err := proto.UnmarshalJSONEnum(FOO_value, data, "FOO")
if err != nil {
return err
}
*x = FOO(value)
return nil
}
type Test struct {
Label *string `protobuf:"bytes,1,req,name=label" json:"label,omitempty"`
Type *int32 `protobuf:"varint,2,opt,name=type,def=77" json:"type,omitempty"`
Reps []int64 `protobuf:"varint,3,rep,name=reps" json:"reps,omitempty"`
Optionalgroup *Test_OptionalGroup `protobuf:"group,4,opt,name=OptionalGroup" json:"optionalgroup,omitempty"`
XXX_unrecognized []byte `json:"-"`
}
func (m *Test) Reset() { *m = Test{} }
func (m *Test) String() string { return proto.CompactTextString(m) }
func (*Test) ProtoMessage() {}
const Default_Test_Type int32 = 77
func (m *Test) GetLabel() string {
if m != nil && m.Label != nil {
return *m.Label
}
return ""
}
func (m *Test) GetType() int32 {
if m != nil && m.Type != nil {
return *m.Type
}
return Default_Test_Type
}
func (m *Test) GetReps() []int64 {
if m != nil {
return m.Reps
}
return nil
}
func (m *Test) GetOptionalgroup() *Test_OptionalGroup {
if m != nil {
return m.Optionalgroup
}
return nil
}
type Test_OptionalGroup struct {
RequiredField *string `protobuf:"bytes,5,req" json:"RequiredField,omitempty"`
XXX_unrecognized []byte `json:"-"`
}
func (m *Test_OptionalGroup) Reset() { *m = Test_OptionalGroup{} }
func (m *Test_OptionalGroup) String() string { return proto.CompactTextString(m) }
func (*Test_OptionalGroup) ProtoMessage() {}
func (m *Test_OptionalGroup) GetRequiredField() string {
if m != nil && m.RequiredField != nil {
return *m.RequiredField
}
return ""
}
func init() {
proto.RegisterEnum("example.FOO", FOO_name, FOO_value)
}

View File

@@ -0,0 +1,12 @@
package example;
enum FOO {X=17;};
message Test {
required string label = 1;
optional int32 type = 2[default=77];
repeated int64 reps = 3;
optional group OptionalGroup = 4{
required string RequiredField = 5;
}
}

View File

@@ -0,0 +1,36 @@
package binding
import (
"fmt"
"log"
"net/http"
)
type Arg struct {
Max int64 `form:"max" validate:"max=10"`
Min int64 `form:"min" validate:"min=2"`
Range int64 `form:"range" validate:"min=1,max=10"`
// use split option to split arg 1,2,3 into slice [1 2 3]
// otherwise slice type with parse url.Values (eg:a=b&a=c) default.
Slice []int64 `form:"slice,split" validate:"min=1"`
}
func ExampleBinding() {
req := initHTTP("max=9&min=3&range=3&slice=1,2,3")
arg := new(Arg)
if err := Form.Bind(req, arg); err != nil {
log.Fatal(err)
}
fmt.Printf("arg.Max %d\narg.Min %d\narg.Range %d\narg.Slice %v", arg.Max, arg.Min, arg.Range, arg.Slice)
// Output:
// arg.Max 9
// arg.Min 3
// arg.Range 3
// arg.Slice [1 2 3]
}
func initHTTP(params string) (req *http.Request) {
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
req.ParseForm()
return
}

View File

@@ -0,0 +1,55 @@
package binding
import (
"net/http"
"github.com/pkg/errors"
)
const defaultMemory = 32 * 1024 * 1024
type formBinding struct{}
type formPostBinding struct{}
type formMultipartBinding struct{}
func (f formBinding) Name() string {
return "form"
}
func (f formBinding) Bind(req *http.Request, obj interface{}) error {
if err := req.ParseForm(); err != nil {
return errors.WithStack(err)
}
if err := mapForm(obj, req.Form); err != nil {
return err
}
return validate(obj)
}
func (f formPostBinding) Name() string {
return "form-urlencoded"
}
func (f formPostBinding) Bind(req *http.Request, obj interface{}) error {
if err := req.ParseForm(); err != nil {
return errors.WithStack(err)
}
if err := mapForm(obj, req.PostForm); err != nil {
return err
}
return validate(obj)
}
func (f formMultipartBinding) Name() string {
return "multipart/form-data"
}
func (f formMultipartBinding) Bind(req *http.Request, obj interface{}) error {
if err := req.ParseMultipartForm(defaultMemory); err != nil {
return errors.WithStack(err)
}
if err := mapForm(obj, req.MultipartForm.Value); err != nil {
return err
}
return validate(obj)
}

View File

@@ -0,0 +1,276 @@
package binding
import (
"reflect"
"strconv"
"strings"
"sync"
"time"
"github.com/pkg/errors"
)
// scache struct reflect type cache.
var scache = &cache{
data: make(map[reflect.Type]*sinfo),
}
type cache struct {
data map[reflect.Type]*sinfo
mutex sync.RWMutex
}
func (c *cache) get(obj reflect.Type) (s *sinfo) {
var ok bool
c.mutex.RLock()
if s, ok = c.data[obj]; !ok {
c.mutex.RUnlock()
s = c.set(obj)
return
}
c.mutex.RUnlock()
return
}
func (c *cache) set(obj reflect.Type) (s *sinfo) {
s = new(sinfo)
tp := obj.Elem()
for i := 0; i < tp.NumField(); i++ {
fd := new(field)
fd.tp = tp.Field(i)
tag := fd.tp.Tag.Get("form")
fd.name, fd.option = parseTag(tag)
if defV := fd.tp.Tag.Get("default"); defV != "" {
dv := reflect.New(fd.tp.Type).Elem()
setWithProperType(fd.tp.Type.Kind(), []string{defV}, dv, fd.option)
fd.hasDefault = true
fd.defaultValue = dv
}
s.field = append(s.field, fd)
}
c.mutex.Lock()
c.data[obj] = s
c.mutex.Unlock()
return
}
type sinfo struct {
field []*field
}
type field struct {
tp reflect.StructField
name string
option tagOptions
hasDefault bool // if field had default value
defaultValue reflect.Value // field default value
}
func mapForm(ptr interface{}, form map[string][]string) error {
sinfo := scache.get(reflect.TypeOf(ptr))
val := reflect.ValueOf(ptr).Elem()
for i, fd := range sinfo.field {
typeField := fd.tp
structField := val.Field(i)
if !structField.CanSet() {
continue
}
structFieldKind := structField.Kind()
inputFieldName := fd.name
if inputFieldName == "" {
inputFieldName = typeField.Name
// if "form" tag is nil, we inspect if the field is a struct.
// this would not make sense for JSON parsing but it does for a form
// since data is flatten
if structFieldKind == reflect.Struct {
err := mapForm(structField.Addr().Interface(), form)
if err != nil {
return err
}
continue
}
}
inputValue, exists := form[inputFieldName]
if !exists {
// Set the field as default value when the input value is not exist
if fd.hasDefault {
structField.Set(fd.defaultValue)
}
continue
}
// Set the field as default value when the input value is empty
if fd.hasDefault && inputValue[0] == "" {
structField.Set(fd.defaultValue)
continue
}
if _, isTime := structField.Interface().(time.Time); isTime {
if err := setTimeField(inputValue[0], typeField, structField); err != nil {
return err
}
continue
}
if err := setWithProperType(typeField.Type.Kind(), inputValue, structField, fd.option); err != nil {
return err
}
}
return nil
}
func setWithProperType(valueKind reflect.Kind, val []string, structField reflect.Value, option tagOptions) error {
switch valueKind {
case reflect.Int:
return setIntField(val[0], 0, structField)
case reflect.Int8:
return setIntField(val[0], 8, structField)
case reflect.Int16:
return setIntField(val[0], 16, structField)
case reflect.Int32:
return setIntField(val[0], 32, structField)
case reflect.Int64:
return setIntField(val[0], 64, structField)
case reflect.Uint:
return setUintField(val[0], 0, structField)
case reflect.Uint8:
return setUintField(val[0], 8, structField)
case reflect.Uint16:
return setUintField(val[0], 16, structField)
case reflect.Uint32:
return setUintField(val[0], 32, structField)
case reflect.Uint64:
return setUintField(val[0], 64, structField)
case reflect.Bool:
return setBoolField(val[0], structField)
case reflect.Float32:
return setFloatField(val[0], 32, structField)
case reflect.Float64:
return setFloatField(val[0], 64, structField)
case reflect.String:
structField.SetString(val[0])
case reflect.Slice:
if option.Contains("split") {
val = strings.Split(val[0], ",")
}
filtered := filterEmpty(val)
switch structField.Type().Elem().Kind() {
case reflect.Int64:
valSli := make([]int64, 0, len(filtered))
for i := 0; i < len(filtered); i++ {
d, err := strconv.ParseInt(filtered[i], 10, 64)
if err != nil {
return err
}
valSli = append(valSli, d)
}
structField.Set(reflect.ValueOf(valSli))
case reflect.String:
valSli := make([]string, 0, len(filtered))
for i := 0; i < len(filtered); i++ {
valSli = append(valSli, filtered[i])
}
structField.Set(reflect.ValueOf(valSli))
default:
sliceOf := structField.Type().Elem().Kind()
numElems := len(filtered)
slice := reflect.MakeSlice(structField.Type(), len(filtered), len(filtered))
for i := 0; i < numElems; i++ {
if err := setWithProperType(sliceOf, filtered[i:], slice.Index(i), ""); err != nil {
return err
}
}
structField.Set(slice)
}
default:
return errors.New("Unknown type")
}
return nil
}
func setIntField(val string, bitSize int, field reflect.Value) error {
if val == "" {
val = "0"
}
intVal, err := strconv.ParseInt(val, 10, bitSize)
if err == nil {
field.SetInt(intVal)
}
return errors.WithStack(err)
}
func setUintField(val string, bitSize int, field reflect.Value) error {
if val == "" {
val = "0"
}
uintVal, err := strconv.ParseUint(val, 10, bitSize)
if err == nil {
field.SetUint(uintVal)
}
return errors.WithStack(err)
}
func setBoolField(val string, field reflect.Value) error {
if val == "" {
val = "false"
}
boolVal, err := strconv.ParseBool(val)
if err == nil {
field.SetBool(boolVal)
}
return nil
}
func setFloatField(val string, bitSize int, field reflect.Value) error {
if val == "" {
val = "0.0"
}
floatVal, err := strconv.ParseFloat(val, bitSize)
if err == nil {
field.SetFloat(floatVal)
}
return errors.WithStack(err)
}
func setTimeField(val string, structField reflect.StructField, value reflect.Value) error {
timeFormat := structField.Tag.Get("time_format")
if timeFormat == "" {
return errors.New("Blank time format")
}
if val == "" {
value.Set(reflect.ValueOf(time.Time{}))
return nil
}
l := time.Local
if isUTC, _ := strconv.ParseBool(structField.Tag.Get("time_utc")); isUTC {
l = time.UTC
}
if locTag := structField.Tag.Get("time_location"); locTag != "" {
loc, err := time.LoadLocation(locTag)
if err != nil {
return errors.WithStack(err)
}
l = loc
}
t, err := time.ParseInLocation(timeFormat, val, l)
if err != nil {
return errors.WithStack(err)
}
value.Set(reflect.ValueOf(t))
return nil
}
func filterEmpty(val []string) []string {
filtered := make([]string, 0, len(val))
for _, v := range val {
if v != "" {
filtered = append(filtered, v)
}
}
return filtered
}

View File

@@ -0,0 +1,22 @@
package binding
import (
"encoding/json"
"net/http"
"github.com/pkg/errors"
)
type jsonBinding struct{}
func (jsonBinding) Name() string {
return "json"
}
func (jsonBinding) Bind(req *http.Request, obj interface{}) error {
decoder := json.NewDecoder(req.Body)
if err := decoder.Decode(obj); err != nil {
return errors.WithStack(err)
}
return validate(obj)
}

View File

@@ -0,0 +1,19 @@
package binding
import (
"net/http"
)
type queryBinding struct{}
func (queryBinding) Name() string {
return "query"
}
func (queryBinding) Bind(req *http.Request, obj interface{}) error {
values := req.URL.Query()
if err := mapForm(obj, values); err != nil {
return err
}
return validate(obj)
}

View File

@@ -0,0 +1,44 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package binding
import (
"strings"
)
// tagOptions is the string following a comma in a struct field's "json"
// tag, or the empty string. It does not include the leading comma.
type tagOptions string
// parseTag splits a struct field's json tag into its name and
// comma-separated options.
func parseTag(tag string) (string, tagOptions) {
if idx := strings.Index(tag, ","); idx != -1 {
return tag[:idx], tagOptions(tag[idx+1:])
}
return tag, tagOptions("")
}
// Contains reports whether a comma-separated list of options
// contains a particular substr flag. substr must be surrounded by a
// string boundary or commas.
func (o tagOptions) Contains(optionName string) bool {
if len(o) == 0 {
return false
}
s := string(o)
for s != "" {
var next string
i := strings.Index(s, ",")
if i >= 0 {
s, next = s[:i], s[i+1:]
}
if s == optionName {
return true
}
s = next
}
return false
}

View File

@@ -0,0 +1,209 @@
package binding
import (
"bytes"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
type testInterface interface {
String() string
}
type substructNoValidation struct {
IString string
IInt int
}
type mapNoValidationSub map[string]substructNoValidation
type structNoValidationValues struct {
substructNoValidation
Boolean bool
Uinteger uint
Integer int
Integer8 int8
Integer16 int16
Integer32 int32
Integer64 int64
Uinteger8 uint8
Uinteger16 uint16
Uinteger32 uint32
Uinteger64 uint64
Float32 float32
Float64 float64
String string
Date time.Time
Struct substructNoValidation
InlinedStruct struct {
String []string
Integer int
}
IntSlice []int
IntPointerSlice []*int
StructPointerSlice []*substructNoValidation
StructSlice []substructNoValidation
InterfaceSlice []testInterface
UniversalInterface interface{}
CustomInterface testInterface
FloatMap map[string]float32
StructMap mapNoValidationSub
}
func createNoValidationValues() structNoValidationValues {
integer := 1
s := structNoValidationValues{
Boolean: true,
Uinteger: 1 << 29,
Integer: -10000,
Integer8: 120,
Integer16: -20000,
Integer32: 1 << 29,
Integer64: 1 << 61,
Uinteger8: 250,
Uinteger16: 50000,
Uinteger32: 1 << 31,
Uinteger64: 1 << 62,
Float32: 123.456,
Float64: 123.456789,
String: "text",
Date: time.Time{},
CustomInterface: &bytes.Buffer{},
Struct: substructNoValidation{},
IntSlice: []int{-3, -2, 1, 0, 1, 2, 3},
IntPointerSlice: []*int{&integer},
StructSlice: []substructNoValidation{},
UniversalInterface: 1.2,
FloatMap: map[string]float32{
"foo": 1.23,
"bar": 232.323,
},
StructMap: mapNoValidationSub{
"foo": substructNoValidation{},
"bar": substructNoValidation{},
},
// StructPointerSlice []noValidationSub
// InterfaceSlice []testInterface
}
s.InlinedStruct.Integer = 1000
s.InlinedStruct.String = []string{"first", "second"}
s.IString = "substring"
s.IInt = 987654
return s
}
func TestValidateNoValidationValues(t *testing.T) {
origin := createNoValidationValues()
test := createNoValidationValues()
empty := structNoValidationValues{}
assert.Nil(t, validate(test))
assert.Nil(t, validate(&test))
assert.Nil(t, validate(empty))
assert.Nil(t, validate(&empty))
assert.Equal(t, origin, test)
}
type structNoValidationPointer struct {
// substructNoValidation
Boolean bool
Uinteger *uint
Integer *int
Integer8 *int8
Integer16 *int16
Integer32 *int32
Integer64 *int64
Uinteger8 *uint8
Uinteger16 *uint16
Uinteger32 *uint32
Uinteger64 *uint64
Float32 *float32
Float64 *float64
String *string
Date *time.Time
Struct *substructNoValidation
IntSlice *[]int
IntPointerSlice *[]*int
StructPointerSlice *[]*substructNoValidation
StructSlice *[]substructNoValidation
InterfaceSlice *[]testInterface
FloatMap *map[string]float32
StructMap *mapNoValidationSub
}
func TestValidateNoValidationPointers(t *testing.T) {
//origin := createNoValidation_values()
//test := createNoValidation_values()
empty := structNoValidationPointer{}
//assert.Nil(t, validate(test))
//assert.Nil(t, validate(&test))
assert.Nil(t, validate(empty))
assert.Nil(t, validate(&empty))
//assert.Equal(t, origin, test)
}
type Object map[string]interface{}
func TestValidatePrimitives(t *testing.T) {
obj := Object{"foo": "bar", "bar": 1}
assert.NoError(t, validate(obj))
assert.NoError(t, validate(&obj))
assert.Equal(t, obj, Object{"foo": "bar", "bar": 1})
obj2 := []Object{{"foo": "bar", "bar": 1}, {"foo": "bar", "bar": 1}}
assert.NoError(t, validate(obj2))
assert.NoError(t, validate(&obj2))
nu := 10
assert.NoError(t, validate(nu))
assert.NoError(t, validate(&nu))
assert.Equal(t, nu, 10)
str := "value"
assert.NoError(t, validate(str))
assert.NoError(t, validate(&str))
assert.Equal(t, str, "value")
}
// structCustomValidation is a helper struct we use to check that
// custom validation can be registered on it.
// The `notone` binding directive is for custom validation and registered later.
// type structCustomValidation struct {
// Integer int `binding:"notone"`
// }
// notOne is a custom validator meant to be used with `validator.v8` library.
// The method signature for `v9` is significantly different and this function
// would need to be changed for tests to pass after upgrade.
// See https://github.com/gin-gonic/gin/pull/1015.
// func notOne(
// v *validator.Validate, topStruct reflect.Value, currentStructOrField reflect.Value,
// field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string,
// ) bool {
// if val, ok := field.Interface().(int); ok {
// return val != 1
// }
// return false
// }

View File

@@ -0,0 +1,22 @@
package binding
import (
"encoding/xml"
"net/http"
"github.com/pkg/errors"
)
type xmlBinding struct{}
func (xmlBinding) Name() string {
return "xml"
}
func (xmlBinding) Bind(req *http.Request, obj interface{}) error {
decoder := xml.NewDecoder(req.Body)
if err := decoder.Decode(obj); err != nil {
return errors.WithStack(err)
}
return validate(obj)
}

View File

@@ -0,0 +1,430 @@
package blademaster
import (
"bytes"
"context"
"crypto/md5"
"crypto/tls"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net"
"net/url"
"os"
"runtime"
"strconv"
"strings"
"sync"
"time"
xhttp "net/http"
"go-common/library/conf/env"
"go-common/library/log"
"go-common/library/net/metadata"
"go-common/library/net/netutil/breaker"
"go-common/library/stat"
xtime "go-common/library/time"
"github.com/gogo/protobuf/proto"
pkgerr "github.com/pkg/errors"
)
const (
_minRead = 16 * 1024 // 16kb
_appKey = "appkey"
_appSecret = "appsecret"
_ts = "ts"
)
var (
_noKickUserAgent = "haoguanwei@bilibili.com "
clientStats = stat.HTTPClient
)
func init() {
n, err := os.Hostname()
if err == nil {
_noKickUserAgent = _noKickUserAgent + runtime.Version() + " " + n
}
}
// App bilibili intranet authorization.
type App struct {
Key string
Secret string
}
// ClientConfig is http client conf.
type ClientConfig struct {
*App
Dial xtime.Duration
Timeout xtime.Duration
KeepAlive xtime.Duration
Breaker *breaker.Config
URL map[string]*ClientConfig
Host map[string]*ClientConfig
}
// Client is http client.
type Client struct {
conf *ClientConfig
client *xhttp.Client
dialer *net.Dialer
transport xhttp.RoundTripper
urlConf map[string]*ClientConfig
hostConf map[string]*ClientConfig
mutex sync.RWMutex
breaker *breaker.Group
}
// NewClient new a http client.
func NewClient(c *ClientConfig) *Client {
client := new(Client)
client.conf = c
client.dialer = &net.Dialer{
Timeout: time.Duration(c.Dial),
KeepAlive: time.Duration(c.KeepAlive),
}
originTransport := &xhttp.Transport{
DialContext: client.dialer.DialContext,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
// wraps RoundTripper for tracer
client.transport = &TraceTransport{RoundTripper: originTransport}
client.client = &xhttp.Client{
Transport: client.transport,
}
client.urlConf = make(map[string]*ClientConfig)
client.hostConf = make(map[string]*ClientConfig)
client.breaker = breaker.NewGroup(c.Breaker)
// check appkey
if c.Key == "" || c.Secret == "" {
panic("http client must config appkey and appsecret")
}
if c.Timeout <= 0 {
panic("must config http timeout!!!")
}
for uri, cfg := range c.URL {
client.urlConf[uri] = cfg
}
for host, cfg := range c.Host {
client.hostConf[host] = cfg
}
return client
}
// SetTransport set client transport
func (client *Client) SetTransport(t xhttp.RoundTripper) {
client.transport = t
client.client.Transport = t
}
// SetConfig set client config.
func (client *Client) SetConfig(c *ClientConfig) {
client.mutex.Lock()
if c.App != nil {
client.conf.App.Key = c.App.Key
client.conf.App.Secret = c.App.Secret
}
if c.Timeout > 0 {
client.conf.Timeout = c.Timeout
}
if c.KeepAlive > 0 {
client.dialer.KeepAlive = time.Duration(c.KeepAlive)
client.conf.KeepAlive = c.KeepAlive
}
if c.Dial > 0 {
client.dialer.Timeout = time.Duration(c.Dial)
client.conf.Timeout = c.Dial
}
if c.Breaker != nil {
client.conf.Breaker = c.Breaker
client.breaker.Reload(c.Breaker)
}
for uri, cfg := range c.URL {
client.urlConf[uri] = cfg
}
for host, cfg := range c.Host {
client.hostConf[host] = cfg
}
client.mutex.Unlock()
}
// NewRequest new http request with method, uri, ip, values and headers.
// TODO(zhoujiahui): param realIP should be removed later.
func (client *Client) NewRequest(method, uri, realIP string, params url.Values) (req *xhttp.Request, err error) {
enc, err := client.sign(params)
if err != nil {
err = pkgerr.Wrapf(err, "uri:%s,params:%v", uri, params)
return
}
ru := uri
if enc != "" {
ru = uri + "?" + enc
}
if method == xhttp.MethodGet {
req, err = xhttp.NewRequest(xhttp.MethodGet, ru, nil)
} else {
req, err = xhttp.NewRequest(xhttp.MethodPost, uri, strings.NewReader(enc))
}
if err != nil {
err = pkgerr.Wrapf(err, "method:%s,uri:%s", method, ru)
return
}
const (
_contentType = "Content-Type"
_urlencoded = "application/x-www-form-urlencoded"
_userAgent = "User-Agent"
)
if method == xhttp.MethodPost {
req.Header.Set(_contentType, _urlencoded)
}
if realIP != "" {
req.Header.Set(_httpHeaderRemoteIP, realIP)
}
req.Header.Set(_userAgent, _noKickUserAgent+" "+env.AppID)
return
}
// Get issues a GET to the specified URL.
func (client *Client) Get(c context.Context, uri, ip string, params url.Values, res interface{}) (err error) {
req, err := client.NewRequest(xhttp.MethodGet, uri, ip, params)
if err != nil {
return
}
return client.Do(c, req, res)
}
// Post issues a Post to the specified URL.
func (client *Client) Post(c context.Context, uri, ip string, params url.Values, res interface{}) (err error) {
req, err := client.NewRequest(xhttp.MethodPost, uri, ip, params)
if err != nil {
return
}
return client.Do(c, req, res)
}
// RESTfulGet issues a RESTful GET to the specified URL.
func (client *Client) RESTfulGet(c context.Context, uri, ip string, params url.Values, res interface{}, v ...interface{}) (err error) {
req, err := client.NewRequest(xhttp.MethodGet, fmt.Sprintf(uri, v...), ip, params)
if err != nil {
return
}
return client.Do(c, req, res, uri)
}
// RESTfulPost issues a RESTful Post to the specified URL.
func (client *Client) RESTfulPost(c context.Context, uri, ip string, params url.Values, res interface{}, v ...interface{}) (err error) {
req, err := client.NewRequest(xhttp.MethodPost, fmt.Sprintf(uri, v...), ip, params)
if err != nil {
return
}
return client.Do(c, req, res, uri)
}
// Raw sends an HTTP request and returns bytes response
func (client *Client) Raw(c context.Context, req *xhttp.Request, v ...string) (bs []byte, err error) {
var (
ok bool
code string
cancel func()
resp *xhttp.Response
config *ClientConfig
timeout time.Duration
uri = fmt.Sprintf("%s://%s%s", req.URL.Scheme, req.Host, req.URL.Path)
)
// NOTE fix prom & config uri key.
if len(v) == 1 {
uri = v[0]
}
// breaker
brk := client.breaker.Get(uri)
if err = brk.Allow(); err != nil {
code = "breaker"
clientStats.Incr(uri, code)
return
}
defer client.onBreaker(brk, &err)
// stat
now := time.Now()
defer func() {
clientStats.Timing(uri, int64(time.Since(now)/time.Millisecond))
if code != "" {
clientStats.Incr(uri, code)
}
}()
// get config
// 1.url config 2.host config 3.default
client.mutex.RLock()
if config, ok = client.urlConf[uri]; !ok {
if config, ok = client.hostConf[req.Host]; !ok {
config = client.conf
}
}
client.mutex.RUnlock()
// timeout
deliver := true
timeout = time.Duration(config.Timeout)
if deadline, ok := c.Deadline(); ok {
if ctimeout := time.Until(deadline); ctimeout < timeout {
// deliver small timeout
timeout = ctimeout
deliver = false
}
}
if deliver {
c, cancel = context.WithTimeout(c, timeout)
defer cancel()
}
setTimeout(req, timeout)
req = req.WithContext(c)
setCaller(req)
if color := metadata.String(c, metadata.Color); color != "" {
setColor(req, color)
}
if resp, err = client.client.Do(req); err != nil {
err = pkgerr.Wrapf(err, "host:%s, url:%s", req.URL.Host, realURL(req))
code = "failed"
return
}
defer resp.Body.Close()
if resp.StatusCode >= xhttp.StatusBadRequest {
err = pkgerr.Errorf("incorrect http status:%d host:%s, url:%s", resp.StatusCode, req.URL.Host, realURL(req))
code = strconv.Itoa(resp.StatusCode)
return
}
if bs, err = readAll(resp.Body, _minRead); err != nil {
err = pkgerr.Wrapf(err, "host:%s, url:%s", req.URL.Host, realURL(req))
return
}
return
}
// Do sends an HTTP request and returns an HTTP json response.
func (client *Client) Do(c context.Context, req *xhttp.Request, res interface{}, v ...string) (err error) {
var bs []byte
if bs, err = client.Raw(c, req, v...); err != nil {
return
}
if res != nil {
if err = json.Unmarshal(bs, res); err != nil {
err = pkgerr.Wrapf(err, "host:%s, url:%s", req.URL.Host, realURL(req))
}
}
return
}
// JSON sends an HTTP request and returns an HTTP json response.
func (client *Client) JSON(c context.Context, req *xhttp.Request, res interface{}, v ...string) (err error) {
var bs []byte
if bs, err = client.Raw(c, req, v...); err != nil {
return
}
if res != nil {
if err = json.Unmarshal(bs, res); err != nil {
err = pkgerr.Wrapf(err, "host:%s, url:%s", req.URL.Host, realURL(req))
}
}
return
}
// PB sends an HTTP request and returns an HTTP proto response.
func (client *Client) PB(c context.Context, req *xhttp.Request, res proto.Message, v ...string) (err error) {
var bs []byte
if bs, err = client.Raw(c, req, v...); err != nil {
return
}
if res != nil {
if err = proto.Unmarshal(bs, res); err != nil {
err = pkgerr.Wrapf(err, "host:%s, url:%s", req.URL.Host, realURL(req))
}
}
return
}
func (client *Client) onBreaker(breaker breaker.Breaker, err *error) {
if err != nil && *err != nil {
breaker.MarkFailed()
} else {
breaker.MarkSuccess()
}
}
// sign calc appkey and appsecret sign.
func (client *Client) sign(params url.Values) (query string, err error) {
client.mutex.RLock()
key := client.conf.Key
secret := client.conf.Secret
client.mutex.RUnlock()
if params == nil {
params = url.Values{}
}
params.Set(_appKey, key)
if params.Get(_appSecret) != "" {
log.Warn("utils http get must not have parameter appSecret")
}
if params.Get(_ts) == "" {
params.Set(_ts, strconv.FormatInt(time.Now().Unix(), 10))
}
tmp := params.Encode()
if strings.IndexByte(tmp, '+') > -1 {
tmp = strings.Replace(tmp, "+", "%20", -1)
}
var b bytes.Buffer
b.WriteString(tmp)
b.WriteString(secret)
mh := md5.Sum(b.Bytes())
// query
var qb bytes.Buffer
qb.WriteString(tmp)
qb.WriteString("&sign=")
qb.WriteString(hex.EncodeToString(mh[:]))
query = qb.String()
return
}
// realUrl return url with http://host/params.
func realURL(req *xhttp.Request) string {
if req.Method == xhttp.MethodGet {
return req.URL.String()
} else if req.Method == xhttp.MethodPost {
ru := req.URL.Path
if req.Body != nil {
rd, ok := req.Body.(io.Reader)
if ok {
buf := bytes.NewBuffer([]byte{})
buf.ReadFrom(rd)
ru = ru + "?" + buf.String()
}
}
return ru
}
return req.URL.Path
}
// readAll reads from r until an error or EOF and returns the data it read
// from the internal buffer allocated with a specified capacity.
func readAll(r io.Reader, capacity int64) (b []byte, err error) {
buf := bytes.NewBuffer(make([]byte, 0, capacity))
// If the buffer overflows, we will get bytes.ErrTooLarge.
// Return that as an error. Any other panic remains.
defer func() {
e := recover()
if e == nil {
return
}
if panicErr, ok := e.(error); ok && panicErr == bytes.ErrTooLarge {
err = panicErr
} else {
panic(e)
}
}()
_, err = buf.ReadFrom(r)
return buf.Bytes(), err
}

View File

@@ -0,0 +1,375 @@
package blademaster
import (
"context"
"net/http"
"net/url"
"strconv"
"testing"
"time"
"go-common/library/ecode"
"go-common/library/net/http/blademaster/tests"
"go-common/library/net/netutil/breaker"
xtime "go-common/library/time"
)
func TestClient(t *testing.T) {
c := &ServerConfig{
Addr: "localhost:8081",
Timeout: xtime.Duration(time.Second),
ReadTimeout: xtime.Duration(time.Second),
WriteTimeout: xtime.Duration(time.Second),
}
engine := Default()
engine.GET("/mytest", func(ctx *Context) {
time.Sleep(time.Millisecond * 500)
ctx.JSON("", nil)
})
engine.GET("/mytest1", func(ctx *Context) {
time.Sleep(time.Millisecond * 500)
ctx.JSON("", nil)
})
engine.SetConfig(c)
engine.Start()
client := NewClient(
&ClientConfig{
App: &App{
Key: "53e2fa226f5ad348",
Secret: "3cf6bd1b0ff671021da5f424fea4b04a",
},
Dial: xtime.Duration(time.Second),
Timeout: xtime.Duration(time.Second),
KeepAlive: xtime.Duration(time.Second),
Breaker: &breaker.Config{
Window: 10 * xtime.Duration(time.Second),
Sleep: 50 * xtime.Duration(time.Millisecond),
Bucket: 10,
Ratio: 0.5,
Request: 100,
},
})
var res struct {
Code int `json:"code"`
}
// test Get
if err := client.Get(context.Background(), "http://api.bilibili.com/x/server/now", "", nil, &res); err != nil {
t.Errorf("HTTPClient: expected no error but got %v", err)
}
if res.Code != 0 {
t.Errorf("HTTPClient: expected code=0 but got %d", res.Code)
}
// test Post
if err := client.Post(context.Background(), "http://api.bilibili.com/x/server/now", "", nil, &res); err != nil {
t.Errorf("HTTPClient: expected no error but got %v", err)
}
if res.Code != -405 {
t.Errorf("HTTPClient: expected code=-405 but got %d", res.Code)
}
// test DialTimeout 172.168.1.1 can't connect.
client.SetConfig(&ClientConfig{Dial: xtime.Duration(time.Second * 5)})
if err := client.Post(context.Background(), "http://172.168.1.1/x/server/now", "", nil, &res); err == nil {
t.Errorf("HTTPClient: expected error but got %v", err)
}
// test server and timeout.
client.SetConfig(&ClientConfig{KeepAlive: xtime.Duration(time.Second * 20), Timeout: xtime.Duration(time.Millisecond * 400)})
if err := client.Get(context.Background(), "http://localhost:8081/mytest", "", nil, &res); err == nil {
t.Errorf("HTTPClient: expected error timeout for request")
}
client.SetConfig(&ClientConfig{Timeout: xtime.Duration(time.Second),
URL: map[string]*ClientConfig{"http://localhost:8081/mytest1": {Timeout: xtime.Duration(time.Millisecond * 300)}}})
if err := client.Get(context.Background(), "http://localhost:8081/mytest", "", nil, &res); err != nil {
t.Errorf("HTTPClient: expected no error but got %v", err)
}
if err := client.Get(context.Background(), "http://localhost:8081/mytest1", "", nil, &res); err == nil {
t.Errorf("HTTPClient: expected error timeout for path")
}
client.SetConfig(&ClientConfig{
Host: map[string]*ClientConfig{"api.bilibili.com": {Timeout: xtime.Duration(time.Millisecond * 300)}},
})
if err := client.Get(context.Background(), "http://api.bilibili.com/x/server/now", "", nil, &res); err != nil {
t.Errorf("HTTPClient: expected no error but got %v", err)
}
client.SetConfig(&ClientConfig{
Host: map[string]*ClientConfig{"api.bilibili.com": {Timeout: xtime.Duration(time.Millisecond * 1)}},
})
if err := client.Get(context.Background(), "http://api.bilibili.com/x/server/now", "", nil, &res); err == nil {
t.Errorf("HTTPClient: expected error timeout but got %v", err)
}
client.SetConfig(&ClientConfig{KeepAlive: xtime.Duration(time.Second * 70)})
}
func TestDo(t *testing.T) {
var (
aid = 5463320
uri = "http://api.bilibili.com/x/server/now"
req *http.Request
client *Client
err error
)
client = NewClient(
&ClientConfig{
App: &App{
Key: "53e2fa226f5ad348",
Secret: "3cf6bd1b0ff671021da5f424fea4b04a",
},
Dial: xtime.Duration(time.Second),
Timeout: xtime.Duration(time.Second),
KeepAlive: xtime.Duration(time.Second),
Breaker: &breaker.Config{
Window: 10 * xtime.Duration(time.Second),
Sleep: 50 * xtime.Duration(time.Millisecond),
Bucket: 10,
Ratio: 0.5,
Request: 100,
},
})
params := url.Values{}
params.Set("aid", strconv.Itoa(aid))
if req, err = client.NewRequest("GET", uri, "", params); err != nil {
t.Errorf("client.NewRequest: get error(%v)", err)
}
var res struct {
Code int `json:"code"`
}
if err = client.Do(context.TODO(), req, &res); err != nil {
t.Errorf("Do: client.Do get error(%v) url: %s", err, realURL(req))
}
}
func BenchmarkDo(b *testing.B) {
once.Do(startServer)
cf := &ClientConfig{
App: &App{
Key: "53e2fa226f5ad348",
Secret: "3cf6bd1b0ff671021da5f424fea4b04a",
},
Dial: xtime.Duration(time.Second),
Timeout: xtime.Duration(time.Second),
KeepAlive: xtime.Duration(time.Second),
Breaker: &breaker.Config{
Window: 1 * xtime.Duration(time.Second),
Sleep: 5 * xtime.Duration(time.Millisecond),
Bucket: 1,
Ratio: 0.5,
Request: 10,
},
URL: map[string]*ClientConfig{
"http://api.bilibili.com/x/server/now": {Timeout: xtime.Duration(time.Second)},
"http://api.bilibili.com/x/server/nowx": {Timeout: xtime.Duration(time.Second)},
},
}
client := NewClient(cf)
uri := "http://api.bilibili.com/x/server/now"
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
// client.SetConfig(cf)
req, err := client.NewRequest("GET", uri, "", nil)
if err != nil {
b.Errorf("newRequest: get error(%v)", err)
continue
}
var res struct {
Code int `json:"code"`
}
if err = client.Do(context.TODO(), req, &res); err != nil {
b.Errorf("Do: client.Do get error(%v) url: %s", err, realURL(req))
}
}
})
uri = "http://api.bilibili.com/x/server/nowx" // NOTE: for breaker
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
// client.SetConfig(cf)
req, err := client.NewRequest("GET", uri, "", nil)
if err != nil {
b.Errorf("newRequest: get error(%v)", err)
continue
}
var res struct {
Code int `json:"code"`
}
if err = client.Do(context.TODO(), req, &res); err != nil {
if ecode.ServiceUnavailable.Equal(err) {
b.Logf("Do: client.Do get error(%v) url: %s", err, realURL(req))
}
}
}
})
}
func TestRESTfulClient(t *testing.T) {
c := &ServerConfig{
Addr: "localhost:8082",
Timeout: xtime.Duration(time.Second),
ReadTimeout: xtime.Duration(time.Second),
WriteTimeout: xtime.Duration(time.Second),
}
engine := Default()
engine.GET("/mytest/1", func(ctx *Context) {
time.Sleep(time.Millisecond * 500)
ctx.JSON("", nil)
})
engine.GET("/mytest/2/1", func(ctx *Context) {
time.Sleep(time.Millisecond * 500)
// ctx.AbortWithStatus(http.StatusInternalServerError)
ctx.JSON(nil, ecode.ServerErr)
})
engine.SetConfig(c)
engine.Start()
client := NewClient(
&ClientConfig{
App: &App{
Key: "53e2fa226f5ad348",
Secret: "3cf6bd1b0ff671021da5f424fea4b04a",
},
Dial: xtime.Duration(time.Second),
Timeout: xtime.Duration(time.Second),
KeepAlive: xtime.Duration(time.Second),
Breaker: &breaker.Config{
Window: 10 * xtime.Duration(time.Second),
Sleep: 50 * xtime.Duration(time.Millisecond),
Bucket: 10,
Ratio: 0.5,
Request: 100,
},
})
var res struct {
Code int `json:"code"`
}
if err := client.RESTfulGet(context.Background(), "http://localhost:8082/mytest/%d", "", nil, &res, 1); err != nil {
t.Errorf("HTTPClient: expected error RESTfulGet err: %v", err)
}
if res.Code != 0 {
t.Errorf("HTTPClient: expected code=0 but got %d", res.Code)
}
if err := client.RESTfulGet(context.Background(), "http://localhost:8082/mytest/%d/%d", "", nil, &res, 2, 1); err != nil {
t.Errorf("HTTPClient: expected error RESTfulGet err: %v", err)
}
if res.Code != -500 {
t.Errorf("HTTPClient: expected code=-500 but got %d", res.Code)
}
}
func TestRaw(t *testing.T) {
var (
aid = 5463320
uri = "http://api.bilibili.com/x/server/now"
req *http.Request
client *Client
err error
)
client = NewClient(
&ClientConfig{
App: &App{
Key: "53e2fa226f5ad348",
Secret: "3cf6bd1b0ff671021da5f424fea4b04a",
},
Dial: xtime.Duration(time.Second),
Timeout: xtime.Duration(time.Second),
KeepAlive: xtime.Duration(time.Second),
Breaker: &breaker.Config{
Window: 10 * xtime.Duration(time.Second),
Sleep: 50 * xtime.Duration(time.Millisecond),
Bucket: 10,
Ratio: 0.5,
Request: 100,
},
})
params := url.Values{}
params.Set("aid", strconv.Itoa(aid))
if req, err = client.NewRequest("GET", uri, "", params); err != nil {
t.Errorf("client.NewRequest: get error(%v)", err)
}
var (
bs []byte
)
if bs, err = client.Raw(context.TODO(), req); err != nil {
t.Errorf("Do: client.Do get error(%v) url: %s", err, realURL(req))
}
t.Log(string(bs))
}
func TestJSON(t *testing.T) {
var (
aid = 5463320
uri = "http://api.bilibili.com/x/server/now"
req *http.Request
client *Client
err error
)
client = NewClient(
&ClientConfig{
App: &App{
Key: "53e2fa226f5ad348",
Secret: "3cf6bd1b0ff671021da5f424fea4b04a",
},
Dial: xtime.Duration(time.Second),
Timeout: xtime.Duration(time.Second),
KeepAlive: xtime.Duration(time.Second),
Breaker: &breaker.Config{
Window: 10 * xtime.Duration(time.Second),
Sleep: 50 * xtime.Duration(time.Millisecond),
Bucket: 10,
Ratio: 0.5,
Request: 100,
},
})
params := url.Values{}
params.Set("aid", strconv.Itoa(aid))
if req, err = client.NewRequest("GET", uri, "", params); err != nil {
t.Errorf("client.NewRequest: get error(%v)", err)
}
var res struct {
Code int `json:"code"`
}
if err = client.Do(context.TODO(), req, &res); err != nil {
t.Errorf("Do: client.Do get error(%v) url: %s", err, realURL(req))
}
}
func TestPB(t *testing.T) {
var (
uri = "http://172.18.33.143:13500/playurl/batch"
req *http.Request
client *Client
err error
)
client = NewClient(
&ClientConfig{
App: &App{
Key: "53e2fa226f5ad348",
Secret: "3cf6bd1b0ff671021da5f424fea4b04a",
},
Dial: xtime.Duration(time.Second),
Timeout: xtime.Duration(time.Second),
KeepAlive: xtime.Duration(time.Second),
Breaker: &breaker.Config{
Window: 10 * xtime.Duration(time.Second),
Sleep: 50 * xtime.Duration(time.Millisecond),
Bucket: 10,
Ratio: 0.5,
Request: 100,
},
})
params := url.Values{}
params.Set("cid", "10108859,10108860")
params.Set("uip", "222.73.196.18")
params.Set("qn", "16")
params.Set("platform", "html5")
params.Set("layout", "pb")
if req, err = client.NewRequest("GET", uri, "", params); err != nil {
t.Errorf("client.NewRequest: get error(%v)", err)
}
var res = new(tests.BvcResponseMsg)
if err = client.PB(context.TODO(), req, res); err != nil {
t.Errorf("Do: client.Do get error(%v) url: %s", err, realURL(req))
}
t.Log(res)
}

View File

@@ -0,0 +1,306 @@
package blademaster
import (
"context"
"math"
"net/http"
"strconv"
"go-common/library/ecode"
"go-common/library/net/http/blademaster/binding"
"go-common/library/net/http/blademaster/render"
"github.com/gogo/protobuf/proto"
"github.com/gogo/protobuf/types"
"github.com/pkg/errors"
)
const (
_abortIndex int8 = math.MaxInt8 / 2
)
var (
_openParen = []byte("(")
_closeParen = []byte(")")
)
// Context is the most important part. It allows us to pass variables between
// middleware, manage the flow, validate the JSON of a request and render a
// JSON response for example.
type Context struct {
context.Context
Request *http.Request
Writer http.ResponseWriter
// flow control
index int8
handlers []HandlerFunc
// Keys is a key/value pair exclusively for the context of each request.
Keys map[string]interface{}
Error error
method string
engine *Engine
}
/************************************/
/*********** FLOW CONTROL ***********/
/************************************/
// Next should be used only inside middleware.
// It executes the pending handlers in the chain inside the calling handler.
// See example in godoc.
func (c *Context) Next() {
c.index++
s := int8(len(c.handlers))
for ; c.index < s; c.index++ {
// only check method on last handler, otherwise middlewares
// will never be effected if request method is not matched
if c.index == s-1 && c.method != c.Request.Method {
code := http.StatusMethodNotAllowed
c.Error = ecode.MethodNotAllowed
http.Error(c.Writer, http.StatusText(code), code)
return
}
c.handlers[c.index](c)
}
}
// Abort prevents pending handlers from being called. Note that this will not stop the current handler.
// Let's say you have an authorization middleware that validates that the current request is authorized.
// If the authorization fails (ex: the password does not match), call Abort to ensure the remaining handlers
// for this request are not called.
func (c *Context) Abort() {
c.index = _abortIndex
}
// AbortWithStatus calls `Abort()` and writes the headers with the specified status code.
// For example, a failed attempt to authenticate a request could use: context.AbortWithStatus(401).
func (c *Context) AbortWithStatus(code int) {
c.Status(code)
c.Abort()
}
// IsAborted returns true if the current context was aborted.
func (c *Context) IsAborted() bool {
return c.index >= _abortIndex
}
/************************************/
/******** METADATA MANAGEMENT********/
/************************************/
// Set is used to store a new key/value pair exclusively for this context.
// It also lazy initializes c.Keys if it was not used previously.
func (c *Context) Set(key string, value interface{}) {
if c.Keys == nil {
c.Keys = make(map[string]interface{})
}
c.Keys[key] = value
}
// Get returns the value for the given key, ie: (value, true).
// If the value does not exists it returns (nil, false)
func (c *Context) Get(key string) (value interface{}, exists bool) {
value, exists = c.Keys[key]
return
}
/************************************/
/******** RESPONSE RENDERING ********/
/************************************/
// bodyAllowedForStatus is a copy of http.bodyAllowedForStatus non-exported function.
func bodyAllowedForStatus(status int) bool {
switch {
case status >= 100 && status <= 199:
return false
case status == 204:
return false
case status == 304:
return false
}
return true
}
// Status sets the HTTP response code.
func (c *Context) Status(code int) {
c.Writer.WriteHeader(code)
}
// Render http response with http code by a render instance.
func (c *Context) Render(code int, r render.Render) {
r.WriteContentType(c.Writer)
if code > 0 {
c.Status(code)
}
if !bodyAllowedForStatus(code) {
return
}
params := c.Request.Form
cb := params.Get("callback")
jsonp := cb != "" && params.Get("jsonp") == "jsonp"
if jsonp {
c.Writer.Write([]byte(cb))
c.Writer.Write(_openParen)
}
if err := r.Render(c.Writer); err != nil {
c.Error = err
return
}
if jsonp {
if _, err := c.Writer.Write(_closeParen); err != nil {
c.Error = errors.WithStack(err)
}
}
}
// JSON serializes the given struct as JSON into the response body.
// It also sets the Content-Type as "application/json".
func (c *Context) JSON(data interface{}, err error) {
code := http.StatusOK
c.Error = err
bcode := ecode.Cause(err)
// TODO app allow 5xx?
/*
if bcode.Code() == -500 {
code = http.StatusServiceUnavailable
}
*/
writeStatusCode(c.Writer, bcode.Code())
c.Render(code, render.JSON{
Code: bcode.Code(),
Message: bcode.Message(),
Data: data,
})
}
// JSONMap serializes the given map as map JSON into the response body.
// It also sets the Content-Type as "application/json".
func (c *Context) JSONMap(data map[string]interface{}, err error) {
code := http.StatusOK
c.Error = err
bcode := ecode.Cause(err)
// TODO app allow 5xx?
/*
if bcode.Code() == -500 {
code = http.StatusServiceUnavailable
}
*/
writeStatusCode(c.Writer, bcode.Code())
data["code"] = bcode.Code()
if _, ok := data["message"]; !ok {
data["message"] = bcode.Message()
}
c.Render(code, render.MapJSON(data))
}
// XML serializes the given struct as XML into the response body.
// It also sets the Content-Type as "application/xml".
func (c *Context) XML(data interface{}, err error) {
code := http.StatusOK
c.Error = err
bcode := ecode.Cause(err)
// TODO app allow 5xx?
/*
if bcode.Code() == -500 {
code = http.StatusServiceUnavailable
}
*/
writeStatusCode(c.Writer, bcode.Code())
c.Render(code, render.XML{
Code: bcode.Code(),
Message: bcode.Message(),
Data: data,
})
}
// Protobuf serializes the given struct as PB into the response body.
// It also sets the ContentType as "application/x-protobuf".
func (c *Context) Protobuf(data proto.Message, err error) {
var (
bytes []byte
)
code := http.StatusOK
c.Error = err
bcode := ecode.Cause(err)
any := new(types.Any)
if data != nil {
if bytes, err = proto.Marshal(data); err != nil {
c.Error = errors.WithStack(err)
return
}
any.TypeUrl = "type.googleapis.com/" + proto.MessageName(data)
any.Value = bytes
}
writeStatusCode(c.Writer, bcode.Code())
c.Render(code, render.PB{
Code: int64(bcode.Code()),
Message: bcode.Message(),
Data: any,
})
}
// Bytes writes some data into the body stream and updates the HTTP code.
func (c *Context) Bytes(code int, contentType string, data ...[]byte) {
c.Render(code, render.Data{
ContentType: contentType,
Data: data,
})
}
// String writes the given string into the response body.
func (c *Context) String(code int, format string, values ...interface{}) {
c.Render(code, render.String{Format: format, Data: values})
}
// Redirect returns a HTTP redirect to the specific location.
func (c *Context) Redirect(code int, location string) {
c.Render(-1, render.Redirect{
Code: code,
Location: location,
Request: c.Request,
})
}
// BindWith bind req arg with parser.
func (c *Context) BindWith(obj interface{}, b binding.Binding) error {
return c.mustBindWith(obj, b)
}
// Bind bind req arg with defult form binding.
func (c *Context) Bind(obj interface{}) error {
return c.mustBindWith(obj, binding.Form)
}
// mustBindWith binds the passed struct pointer using the specified binding engine.
// It will abort the request with HTTP 400 if any error ocurrs.
// See the binding package.
func (c *Context) mustBindWith(obj interface{}, b binding.Binding) (err error) {
if err = b.Bind(c.Request, obj); err != nil {
c.Error = ecode.RequestErr
c.Render(http.StatusOK, render.JSON{
Code: ecode.RequestErr.Code(),
Message: err.Error(),
Data: nil,
})
c.Abort()
}
return
}
func writeStatusCode(w http.ResponseWriter, ecode int) {
header := w.Header()
header.Set("bili-status-code", strconv.FormatInt(int64(ecode), 10))
}

View File

@@ -0,0 +1,260 @@
package blademaster
import (
"net/http"
"strconv"
"strings"
"time"
"go-common/library/log"
"github.com/pkg/errors"
)
var (
allowOriginHosts = []string{
".bilibili.com",
".biligame.com",
".bilibili.co",
".im9.com",
".acg.tv",
".hdslb.com",
}
)
// CORSConfig represents all available options for the middleware.
type CORSConfig struct {
AllowAllOrigins bool
// AllowedOrigins is a list of origins a cross-domain request can be executed from.
// If the special "*" value is present in the list, all origins will be allowed.
// Default value is []
AllowOrigins []string
// AllowOriginFunc is a custom function to validate the origin. It take the origin
// as argument and returns true if allowed or false otherwise. If this option is
// set, the content of AllowedOrigins is ignored.
AllowOriginFunc func(origin string) bool
// AllowedMethods is a list of methods the client is allowed to use with
// cross-domain requests. Default value is simple methods (GET and POST)
AllowMethods []string
// AllowedHeaders is list of non simple headers the client is allowed to use with
// cross-domain requests.
AllowHeaders []string
// AllowCredentials indicates whether the request can include user credentials like
// cookies, HTTP authentication or client side SSL certificates.
AllowCredentials bool
// ExposedHeaders indicates which headers are safe to expose to the API of a CORS
// API specification
ExposeHeaders []string
// MaxAge indicates how long (in seconds) the results of a preflight request
// can be cached
MaxAge time.Duration
}
type cors struct {
allowAllOrigins bool
allowCredentials bool
allowOriginFunc func(string) bool
allowOrigins []string
normalHeaders http.Header
preflightHeaders http.Header
}
type converter func(string) string
// Validate is check configuration of user defined.
func (c *CORSConfig) Validate() error {
if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0) {
return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowedOrigins is not needed")
}
if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowOrigins) == 0 {
return errors.New("conflict settings: all origins disabled")
}
for _, origin := range c.AllowOrigins {
if origin != "*" && !strings.HasPrefix(origin, "http://") && !strings.HasPrefix(origin, "https://") {
return errors.New("bad origin: origins must either be '*' or include http:// or https://")
}
}
return nil
}
// CORS returns the location middleware with default configuration.
func CORS() HandlerFunc {
config := &CORSConfig{
AllowMethods: []string{"GET", "POST"},
AllowHeaders: []string{"Origin", "Content-Length", "Content-Type"},
AllowCredentials: true,
MaxAge: time.Duration(0),
AllowOriginFunc: func(origin string) bool {
for _, host := range allowOriginHosts {
if strings.HasSuffix(strings.ToLower(origin), host) {
return true
}
}
return false
},
}
return newCORS(config)
}
// newCORS returns the location middleware with user-defined custom configuration.
func newCORS(config *CORSConfig) HandlerFunc {
if err := config.Validate(); err != nil {
panic(err.Error())
}
cors := &cors{
allowOriginFunc: config.AllowOriginFunc,
allowAllOrigins: config.AllowAllOrigins,
allowCredentials: config.AllowCredentials,
allowOrigins: normalize(config.AllowOrigins),
normalHeaders: generateNormalHeaders(config),
preflightHeaders: generatePreflightHeaders(config),
}
return func(c *Context) {
cors.applyCORS(c)
}
}
func (cors *cors) applyCORS(c *Context) {
origin := c.Request.Header.Get("Origin")
if len(origin) == 0 {
// request is not a CORS request
return
}
if !cors.validateOrigin(origin) {
log.V(5).Info("The request's Origin header `%s` does not match any of allowed origins.", origin)
c.AbortWithStatus(http.StatusForbidden)
return
}
if c.Request.Method == "OPTIONS" {
cors.handlePreflight(c)
defer c.AbortWithStatus(200)
} else {
cors.handleNormal(c)
}
if !cors.allowAllOrigins {
header := c.Writer.Header()
header.Set("Access-Control-Allow-Origin", origin)
}
}
func (cors *cors) validateOrigin(origin string) bool {
if cors.allowAllOrigins {
return true
}
for _, value := range cors.allowOrigins {
if value == origin {
return true
}
}
if cors.allowOriginFunc != nil {
return cors.allowOriginFunc(origin)
}
return false
}
func (cors *cors) handlePreflight(c *Context) {
header := c.Writer.Header()
for key, value := range cors.preflightHeaders {
header[key] = value
}
}
func (cors *cors) handleNormal(c *Context) {
header := c.Writer.Header()
for key, value := range cors.normalHeaders {
header[key] = value
}
}
func generateNormalHeaders(c *CORSConfig) http.Header {
headers := make(http.Header)
if c.AllowCredentials {
headers.Set("Access-Control-Allow-Credentials", "true")
}
// backport support for early browsers
if len(c.AllowMethods) > 0 {
allowMethods := convert(normalize(c.AllowMethods), strings.ToUpper)
value := strings.Join(allowMethods, ",")
headers.Set("Access-Control-Allow-Methods", value)
}
if len(c.ExposeHeaders) > 0 {
exposeHeaders := convert(normalize(c.ExposeHeaders), http.CanonicalHeaderKey)
headers.Set("Access-Control-Expose-Headers", strings.Join(exposeHeaders, ","))
}
if c.AllowAllOrigins {
headers.Set("Access-Control-Allow-Origin", "*")
} else {
headers.Set("Vary", "Origin")
}
return headers
}
func generatePreflightHeaders(c *CORSConfig) http.Header {
headers := make(http.Header)
if c.AllowCredentials {
headers.Set("Access-Control-Allow-Credentials", "true")
}
if len(c.AllowMethods) > 0 {
allowMethods := convert(normalize(c.AllowMethods), strings.ToUpper)
value := strings.Join(allowMethods, ",")
headers.Set("Access-Control-Allow-Methods", value)
}
if len(c.AllowHeaders) > 0 {
allowHeaders := convert(normalize(c.AllowHeaders), http.CanonicalHeaderKey)
value := strings.Join(allowHeaders, ",")
headers.Set("Access-Control-Allow-Headers", value)
}
if c.MaxAge > time.Duration(0) {
value := strconv.FormatInt(int64(c.MaxAge/time.Second), 10)
headers.Set("Access-Control-Max-Age", value)
}
if c.AllowAllOrigins {
headers.Set("Access-Control-Allow-Origin", "*")
} else {
// Always set Vary headers
// see https://github.com/rs/cors/issues/10,
// https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001
headers.Add("Vary", "Origin")
headers.Add("Vary", "Access-Control-Request-Method")
headers.Add("Vary", "Access-Control-Request-Headers")
}
return headers
}
func normalize(values []string) []string {
if values == nil {
return nil
}
distinctMap := make(map[string]bool, len(values))
normalized := make([]string, 0, len(values))
for _, value := range values {
value = strings.TrimSpace(value)
value = strings.ToLower(value)
if _, seen := distinctMap[value]; !seen {
normalized = append(normalized, value)
distinctMap[value] = true
}
}
return normalized
}
func convert(s []string, c converter) []string {
var out []string
for _, i := range s {
out = append(out, c(i))
}
return out
}

View File

@@ -0,0 +1,89 @@
package blademaster
import (
"net/url"
"regexp"
"strings"
"go-common/library/log"
)
var (
_allowHosts = []string{
".bilibili.com",
".bilibili.co",
".biligame.com",
".im9.com",
".acg.tv",
".hdslb.com",
}
_allowPatterns = []string{
// match by wechat appid
`^http(?:s)?://([\w\d]+\.)?servicewechat.com/(wx7564fd5313d24844|wx618ca8c24bf06c33)`,
}
validations = []func(*url.URL) bool{}
)
func matchHostSuffix(suffix string) func(*url.URL) bool {
return func(uri *url.URL) bool {
return strings.HasSuffix(strings.ToLower(uri.Host), suffix)
}
}
func matchPattern(pattern *regexp.Regexp) func(*url.URL) bool {
return func(uri *url.URL) bool {
return pattern.MatchString(strings.ToLower(uri.String()))
}
}
// addHostSuffix add host suffix into validations
func addHostSuffix(suffix string) {
validations = append(validations, matchHostSuffix(suffix))
}
// addPattern add referer pattern into validations
func addPattern(pattern string) {
validations = append(validations, matchPattern(regexp.MustCompile(pattern)))
}
func init() {
for _, r := range _allowHosts {
addHostSuffix(r)
}
for _, p := range _allowPatterns {
addPattern(p)
}
}
// CSRF returns the csrf middleware to prevent invalid cross site request.
// Only referer is checked currently.
func CSRF() HandlerFunc {
return func(c *Context) {
referer := c.Request.Header.Get("Referer")
params := c.Request.Form
cross := (params.Get("callback") != "" && params.Get("jsonp") == "jsonp") || (params.Get("cross_domain") != "")
if referer == "" {
if !cross {
return
}
log.V(5).Info("The request's Referer header is empty.")
c.AbortWithStatus(403)
return
}
illegal := true
if uri, err := url.Parse(referer); err == nil && uri.Host != "" {
for _, validate := range validations {
if validate(uri) {
illegal = false
break
}
}
}
if illegal {
log.V(5).Info("The request's Referer header `%s` does not match any of allowed referers.", referer)
c.AbortWithStatus(403)
return
}
}
}

View File

@@ -0,0 +1,182 @@
package blademaster
import (
"strconv"
"go-common/library/net/metadata"
)
const (
// PlatAndroid is int8 for android.
PlatAndroid = int8(0)
// PlatIPhone is int8 for iphone.
PlatIPhone = int8(1)
// PlatIPad is int8 for ipad.
PlatIPad = int8(2)
// PlatWPhone is int8 for wphone.
PlatWPhone = int8(3)
// PlatAndroidG is int8 for Android Global.
PlatAndroidG = int8(4)
// PlatIPhoneI is int8 for Iphone Global.
PlatIPhoneI = int8(5)
// PlatIPadI is int8 for IPAD Global.
PlatIPadI = int8(6)
// PlatAndroidTV is int8 for AndroidTV Global.
PlatAndroidTV = int8(7)
// PlatAndroidI is int8 for Android Global.
PlatAndroidI = int8(8)
// PlatAndroidB is int8 for Android Blue.
PlatAndroidB = int8(9)
// PlatIPhoneB is int8 for Ios Blue
PlatIPhoneB = int8(10)
// PlatBilistudio is int8 for bilistudio
PlatBilistudio = int8(11)
// PlatAndroidTVYST is int8 for AndroidTV_YST Global.
PlatAndroidTVYST = int8(12)
)
// Device is the mobile device model
type Device struct {
Build int64
Buvid string
Buvid3 string
Channel string
Device string
Sid string
RawPlatform string
RawMobiApp string
}
// Mobile is the default handler
func Mobile() HandlerFunc {
return func(ctx *Context) {
req := ctx.Request
dev := new(Device)
dev.Buvid = req.Header.Get("Buvid")
if buvid3, err := req.Cookie("buvid3"); err == nil && buvid3 != nil {
dev.Buvid3 = buvid3.Value
}
if sid, err := req.Cookie("sid"); err == nil && sid != nil {
dev.Sid = sid.Value
}
if build, err := strconv.ParseInt(req.Form.Get("build"), 10, 64); err == nil {
dev.Build = build
}
dev.Channel = req.Form.Get("channel")
dev.Device = req.Form.Get("device")
dev.RawMobiApp = req.Form.Get("mobi_app")
dev.RawPlatform = req.Form.Get("platform")
ctx.Set("device", dev)
if md, ok := metadata.FromContext(ctx); ok {
md[metadata.Device] = dev
}
}
}
// Plat return platform from raw platform and mobiApp
func (d *Device) Plat() int8 {
switch d.RawMobiApp {
case "iphone":
if d.Device == "pad" {
return PlatIPad
}
return PlatIPhone
case "white":
return PlatIPhone
case "ipad":
return PlatIPad
case "android":
return PlatAndroid
case "win":
return PlatWPhone
case "android_G":
return PlatAndroidG
case "android_i":
return PlatAndroidI
case "android_b":
return PlatAndroidB
case "iphone_i":
if d.Device == "pad" {
return PlatIPadI
}
return PlatIPhoneI
case "ipad_i":
return PlatIPadI
case "iphone_b":
return PlatIPhoneB
case "android_tv":
return PlatAndroidTV
case "android_tv_yst":
return PlatAndroidTVYST
case "bilistudio":
return PlatBilistudio
}
return PlatIPhone
}
// IsAndroid check plat is android or ipad.
func (d *Device) IsAndroid() bool {
plat := d.Plat()
return plat == PlatAndroid ||
plat == PlatAndroidG ||
plat == PlatAndroidB ||
plat == PlatAndroidI ||
plat == PlatBilistudio ||
plat == PlatAndroidTV ||
plat == PlatAndroidTVYST
}
// IsIOS check plat is iphone or ipad.
func (d *Device) IsIOS() bool {
plat := d.Plat()
return plat == PlatIPad ||
plat == PlatIPhone ||
plat == PlatIPadI ||
plat == PlatIPhoneI ||
plat == PlatIPhoneB
}
// IsOverseas is overseas
func (d *Device) IsOverseas() bool {
plat := d.Plat()
return plat == PlatAndroidI || plat == PlatIPhoneI || plat == PlatIPadI
}
// InvalidChannel check source channel is not allow by config channel.
func (d *Device) InvalidChannel(cfgCh string) bool {
plat := d.Plat()
return plat == PlatAndroid && cfgCh != "*" && cfgCh != d.Channel
}
// MobiApp by plat
func (d *Device) MobiApp() string {
plat := d.Plat()
switch plat {
case PlatAndroid:
return "android"
case PlatIPhone:
return "iphone"
case PlatIPad:
return "ipad"
case PlatAndroidI:
return "android_i"
case PlatIPhoneI:
return "iphone_i"
case PlatIPadI:
return "ipad_i"
case PlatAndroidG:
return "android_G"
}
return "iphone"
}
// MobiAPPBuleChange is app blue change.
func (d *Device) MobiAPPBuleChange() string {
switch d.RawMobiApp {
case "android_b":
return "android"
case "iphone_b":
return "iphone"
}
return d.RawMobiApp
}

View File

@@ -0,0 +1,230 @@
package blademaster_test
import (
"io/ioutil"
"log"
"time"
"go-common/library/net/http/blademaster"
"go-common/library/net/http/blademaster/binding"
"go-common/library/net/http/blademaster/middleware/auth"
"go-common/library/net/http/blademaster/middleware/verify"
"go-common/library/net/http/blademaster/tests"
)
// This example start a http server and listen at port 8080,
// it will handle '/ping' and return response in html text
func Example() {
engine := blademaster.Default()
engine.GET("/ping", func(c *blademaster.Context) {
c.String(200, "%s", "pong")
})
engine.Run(":8080")
}
// This example use `RouterGroup` to separate different requests,
// it will handle ('/group1/ping', '/group2/ping') and return response in json
func ExampleRouterGroup() {
engine := blademaster.Default()
group := engine.Group("/group1", blademaster.CORS())
group.GET("/ping", func(c *blademaster.Context) {
c.JSON(map[string]string{"message": "hello"}, nil)
})
group2 := engine.Group("/group2", blademaster.CORS())
group2.GET("/ping", func(c *blademaster.Context) {
c.JSON(map[string]string{"message": "welcome"}, nil)
})
engine.Run(":8080")
}
// This example add two middlewares in the root router by `Use` method,
// it will add CORS headers in response and log total consumed time
func ExampleEngine_Use() {
timeLogger := func() blademaster.HandlerFunc {
return func(c *blademaster.Context) {
start := time.Now()
c.Next()
log.Printf("total consume: %v", time.Since(start))
}
}
engine := blademaster.Default()
engine.Use(blademaster.CORS())
engine.Use(timeLogger())
engine.GET("/ping", func(c *blademaster.Context) {
c.String(200, "%s", "pong")
})
engine.Run(":8080")
}
// This example add two middlewares in the root router by `UseFunc` method,
// it will log total consumed time
func ExampleEngine_UseFunc() {
engine := blademaster.Default()
engine.UseFunc(func(c *blademaster.Context) {
start := time.Now()
c.Next()
log.Printf("total consume: %v", time.Since(start))
})
engine.GET("/ping", func(c *blademaster.Context) {
c.String(200, "%s", "pong")
})
engine.Run(":8080")
}
// This example start a http server through the specified unix socket,
// it will handle '/ping' and return reponse in html text
func ExampleEngine_RunUnix() {
engine := blademaster.Default()
engine.GET("/ping", func(c *blademaster.Context) {
c.String(200, "%s", "pong")
})
unixs, err := ioutil.TempFile("", "engine.sock")
if err != nil {
log.Fatalf("Failed to create temp file: %s", err)
}
if err := engine.RunUnix(unixs.Name()); err != nil {
log.Fatalf("Failed to serve with unix socket: %s", err)
}
}
// This example show how to render response in json format,
// it will render structures as json like: `{"code":0,"message":"0","data":{"Time":"2017-11-14T23:03:22.0523199+08:00"}}`
func ExampleContext_JSON() {
type Data struct {
Time time.Time
}
engine := blademaster.Default()
engine.GET("/ping", func(c *blademaster.Context) {
var d Data
d.Time = time.Now()
c.JSON(d, nil)
})
engine.Run(":8080")
}
// This example show how to render response in protobuf format
// it will marshal whole response content to protobuf
func ExampleContext_Protobuf() {
engine := blademaster.Default()
engine.GET("/ping.pb", func(c *blademaster.Context) {
t := &tests.Time{
Now: time.Now().Unix(),
}
c.Protobuf(t, nil)
})
engine.Run(":8080")
}
// This example show how to render response in XML format,
// it will render structure as XML like: `<Data><Time>2017-11-14T23:03:49.2231458+08:00</Time></Data>`
func ExampleContext_XML() {
type Data struct {
Time time.Time
}
engine := blademaster.Default()
engine.GET("/ping", func(c *blademaster.Context) {
var d Data
d.Time = time.Now()
c.XML(d, nil)
})
engine.Run(":8080")
}
// This example show how to protect your handlers by HTTP basic auth,
// it will validate the baisc auth and abort with status 403 if authentication is invalid
func ExampleContext_Abort() {
engine := blademaster.Default()
engine.UseFunc(func(c *blademaster.Context) {
user, pass, isok := c.Request.BasicAuth()
if !isok || user != "root" || pass != "root" {
c.AbortWithStatus(403)
return
}
})
engine.GET("/auth", func(c *blademaster.Context) {
c.String(200, "%s", "Welcome")
})
engine.Run(":8080")
}
// This example show how to using the default parameter binding to parse the url param from get request,
// it will validate the request and abort with status 400 if params is invalid
func ExampleContext_Bind() {
engine := blademaster.Default()
engine.GET("/bind", func(c *blademaster.Context) {
v := new(struct {
// This mark field `mids` should exist and every element should greater than 1
Mids []int64 `form:"mids" validate:"dive,gt=1,required"`
Title string `form:"title" validate:"required"`
Content string `form:"content"`
// This mark field `cid` should between 1 and 10
Cid int `form:"cid" validate:"min=1,max=10"`
})
err := c.Bind(v)
if err != nil {
// Do not call any write response method in this state,
// the response body is already written in `c.BindWith` method
return
}
c.String(200, "parse params by bind %+v", v)
})
engine.Run(":8080")
}
// This example show how to using the json binding to parse the json param from post request body,
// it will validate the request and abort with status 400 if params is invalid
func ExampleContext_BindWith() {
engine := blademaster.Default()
engine.POST("/bindwith", func(c *blademaster.Context) {
v := new(struct {
// This mark field `mids` should exist and every element should greater than 1
Mids []int64 `json:"mids" validate:"dive,gt=1,required"`
Title string `json:"title" validate:"required"`
Content string `json:"content"`
// This mark field `cid` should between 1 and 10
Cid int `json:"cid" validate:"min=1,max=10"`
})
err := c.BindWith(v, binding.JSON)
if err != nil {
// Do not call any write response method in this state,
// the response body is already written in `c.BindWith` method
return
}
c.String(200, "parse params by bindwith %+v", v)
})
engine.Run(":8080")
}
func ExampleEngine_Inject() {
v := verify.New(nil)
auth := auth.New(nil)
engine := blademaster.Default()
engine.Inject("^/index", v.Verify, auth.User)
engine.POST("/index/hello", func(c *blademaster.Context) {
c.JSON("hello, world", nil)
})
engine.Run(":8080")
}

View File

@@ -0,0 +1,71 @@
package blademaster
import (
"fmt"
"strconv"
"time"
"go-common/library/ecode"
"go-common/library/log"
"go-common/library/net/metadata"
)
// Logger is logger middleware
func Logger() HandlerFunc {
const noUser = "no_user"
return func(c *Context) {
now := time.Now()
ip := metadata.String(c, metadata.RemoteIP)
req := c.Request
path := req.URL.Path
params := req.Form
var quota float64
if deadline, ok := c.Context.Deadline(); ok {
quota = time.Until(deadline).Seconds()
}
c.Next()
mid, _ := c.Get("mid")
err := c.Error
cerr := ecode.Cause(err)
dt := time.Since(now)
caller := metadata.String(c, metadata.Caller)
if caller == "" {
caller = noUser
}
stats.Incr(caller, path[1:], strconv.FormatInt(int64(cerr.Code()), 10))
stats.Timing(caller, int64(dt/time.Millisecond), path[1:])
lf := log.Infov
errmsg := ""
isSlow := dt >= (time.Millisecond * 500)
if err != nil {
errmsg = err.Error()
lf = log.Errorv
if cerr.Code() > 0 {
lf = log.Warnv
}
} else {
if isSlow {
lf = log.Warnv
}
}
lf(c,
log.KV("method", req.Method),
log.KV("mid", mid),
log.KV("ip", ip),
log.KV("user", caller),
log.KV("path", path),
log.KV("params", params.Encode()),
log.KV("ret", cerr.Code()),
log.KV("msg", cerr.Message()),
log.KV("stack", fmt.Sprintf("%+v", err)),
log.KV("err", errmsg),
log.KV("timeout_quota", quota),
log.KV("ts", dt.Seconds()),
log.KV("source", "http-access-log"),
)
}
}

View File

@@ -0,0 +1,106 @@
package blademaster
import (
"net/http"
"strconv"
"strings"
"time"
"go-common/library/conf/env"
"go-common/library/log"
"github.com/pkg/errors"
)
const (
// http head
_httpHeaderUser = "x1-bilispy-user"
_httpHeaderColor = "x1-bilispy-color"
_httpHeaderTimeout = "x1-bilispy-timeout"
_httpHeaderRemoteIP = "x-backend-bili-real-ip"
_httpHeaderRemoteIPPort = "x-backend-bili-real-ipport"
)
// mirror return true if x1-bilispy-mirror in http header and its value is 1 or true.
func mirror(req *http.Request) bool {
mirrorStr := req.Header.Get("x1-bilispy-mirror")
if mirrorStr == "" {
return false
}
val, err := strconv.ParseBool(mirrorStr)
if err != nil {
log.Warn("blademaster: failed to parse mirror: %+v", errors.Wrap(err, mirrorStr))
return false
}
if !val {
log.Warn("blademaster: request mirrorStr value :%s is false", mirrorStr)
}
return val
}
// setCaller set caller into http request.
func setCaller(req *http.Request) {
req.Header.Set(_httpHeaderUser, env.AppID)
}
// caller get caller from http request.
func caller(req *http.Request) string {
return req.Header.Get(_httpHeaderUser)
}
// setColor set color into http request.
func setColor(req *http.Request, color string) {
req.Header.Set(_httpHeaderColor, color)
}
// color get color from http request.
func color(req *http.Request) string {
c := req.Header.Get(_httpHeaderColor)
if c == "" {
c = env.Color
}
return c
}
// setTimeout set timeout into http request.
func setTimeout(req *http.Request, timeout time.Duration) {
td := int64(timeout / time.Millisecond)
req.Header.Set(_httpHeaderTimeout, strconv.FormatInt(td, 10))
}
// timeout get timeout from http request.
func timeout(req *http.Request) time.Duration {
to := req.Header.Get(_httpHeaderTimeout)
timeout, err := strconv.ParseInt(to, 10, 64)
if err == nil && timeout > 20 {
timeout -= 20 // reduce 20ms every time.
}
return time.Duration(timeout) * time.Millisecond
}
// remoteIP implements a best effort algorithm to return the real client IP, it parses
// X-BACKEND-BILI-REAL-IP or X-Real-IP or X-Forwarded-For in order to work properly with reverse-proxies such us: nginx or haproxy.
// Use X-Forwarded-For before X-Real-Ip as nginx uses X-Real-Ip with the proxy's IP.
func remoteIP(req *http.Request) (remote string) {
if remote = req.Header.Get(_httpHeaderRemoteIP); remote != "" && remote != "null" {
return
}
var xff = req.Header.Get("X-Forwarded-For")
if idx := strings.IndexByte(xff, ','); idx > -1 {
if remote = strings.TrimSpace(xff[:idx]); remote != "" {
return
}
}
if remote = req.Header.Get("X-Real-IP"); remote != "" {
return
}
remote = req.RemoteAddr[:strings.Index(req.RemoteAddr, ":")]
return
}
func remotePort(req *http.Request) (port string) {
if port = req.Header.Get(_httpHeaderRemoteIPPort); port != "" && port != "null" {
return
}
return
}

View File

@@ -0,0 +1,10 @@
### business/blademaster
##### Version 1.0.0
1. 添加基础模块与测试:
- Antispam
- Limiter
- Supervisor
- Degrade

View File

@@ -0,0 +1,8 @@
# Author
maojian
lintnaghui
caoguoliang
zhoujiahui
# Reviewer
maojian

View File

@@ -0,0 +1,12 @@
# See the OWNERS docs at https://go.k8s.io/owners
approvers:
- caoguoliang
- lintnaghui
- maojian
- zhoujiahui
reviewers:
- caoguoliang
- lintnaghui
- maojian
- zhoujiahui

View File

@@ -0,0 +1,26 @@
#### business/blademaster
> Out of Box Middleware
##### 项目简介
来自 bilibili 主站技术部的 blademaster middleware目前以下 middleware 已经 Ready for Production
- Antispam
- Limiter
- Supervisor
- Degrade
##### 项目特点
- 模块化设计,一个模块只干一件事
##### 编译环境
- **请只用 Golang v1.8.x 以上版本编译执行**
##### 依赖包
###### Limiter:
- [x/time/rate](golang.org/x/time/rate)

View File

@@ -0,0 +1,64 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_test",
"go_library",
)
go_test(
name = "go_default_test",
srcs = ["antispam_test.go"],
embed = [":go_default_library"],
rundir = ".",
tags = ["automanaged"],
deps = [
"//library/cache/redis:go_default_library",
"//library/container/pool:go_default_library",
"//library/net/http/blademaster:go_default_library",
"//library/time:go_default_library",
"//vendor/github.com/stretchr/testify/assert:go_default_library",
],
)
go_library(
name = "go_default_library",
srcs = ["antispam.go"],
importpath = "go-common/library/net/http/blademaster/middleware/antispam",
tags = ["automanaged"],
visibility = ["//visibility:public"],
deps = [
"//library/cache/redis:go_default_library",
"//library/ecode:go_default_library",
"//library/log:go_default_library",
"//library/net/http/blademaster:go_default_library",
"//vendor/github.com/pkg/errors:go_default_library",
],
)
go_test(
name = "go_default_xtest",
srcs = ["example_test.go"],
tags = ["automanaged"],
deps = [
"//library/cache/redis:go_default_library",
"//library/container/pool:go_default_library",
"//library/net/http/blademaster:go_default_library",
"//library/net/http/blademaster/middleware/antispam:go_default_library",
"//library/time:go_default_library",
],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)

View File

@@ -0,0 +1,5 @@
### business/blademaster/antispam
##### Version 1.0.0
1. 完成基本功能与测试

View File

@@ -0,0 +1,6 @@
# Author
lintnaghui
caoguoliang
# Reviewer
maojian

View File

@@ -0,0 +1,9 @@
# See the OWNERS docs at https://go.k8s.io/owners
approvers:
- caoguoliang
- lintnaghui
reviewers:
- caoguoliang
- lintnaghui
- maojian

View File

@@ -0,0 +1,13 @@
#### business/blademaster/antispam
##### 项目简介
blademaster 的 antispam middleware主要用于限制用户的请求频率
##### 编译环境
- **请只用 Golang v1.8.x 以上版本编译执行**
##### 依赖包
- No other dependency

View File

@@ -0,0 +1,139 @@
package antispam
import (
"fmt"
"time"
"go-common/library/cache/redis"
"go-common/library/ecode"
"go-common/library/log"
bm "go-common/library/net/http/blademaster"
"github.com/pkg/errors"
)
const (
_prefixRate = "r_%d_%s_%d"
_prefixTotal = "t_%d_%s_%d"
// antispam
_defSecond = 1
_defHour = 1
)
// Antispam is a antispam instance.
type Antispam struct {
redis *redis.Pool
conf *Config
}
// Config antispam config.
type Config struct {
On bool // switch on/off
Second int // every N second allow N requests.
N int // one unit allow N requests.
Hour int // every N hour allow M requests.
M int // one winodw allow M requests.
Redis *redis.Config
}
func (c *Config) validate() error {
if c == nil {
return errors.New("antispam: empty config")
}
if c.Second < _defSecond {
return errors.New("antispam: invalid Second")
}
if c.Hour < _defHour {
return errors.New("antispam: invalid Hour")
}
return nil
}
// New new a antispam service.
func New(c *Config) (s *Antispam) {
if err := c.validate(); err != nil {
panic(err)
}
s = &Antispam{
redis: redis.NewPool(c.Redis),
}
s.Reload(c)
return s
}
// Reload reload antispam config.
func (s *Antispam) Reload(c *Config) {
if err := c.validate(); err != nil {
log.Error("Failed to reload antispam: %+v", err)
return
}
s.conf = c
}
// Rate antispam by user + path.
func (s *Antispam) Rate(c *bm.Context, second, count int) (err error) {
mid, ok := c.Get("mid")
if !ok {
return
}
curSecond := int(time.Now().Unix())
burst := curSecond - curSecond%second
key := rateKey(mid.(int64), c.Request.URL.Path, burst)
return s.antispam(c, key, second, count)
}
// Total antispam by user + path.
func (s *Antispam) Total(c *bm.Context, hour, count int) (err error) {
second := hour * 3600
mid, ok := c.Get("mid")
if !ok {
return
}
curHour := int(time.Now().Unix() / 3600)
burst := curHour - curHour%hour
key := totalKey(mid.(int64), c.Request.URL.Path, burst)
return s.antispam(c, key, second, count)
}
func (s *Antispam) antispam(c *bm.Context, key string, interval, count int) error {
conn := s.redis.Get(c)
defer conn.Close()
incred, err := redis.Int64(conn.Do("INCR", key))
if err != nil {
return nil
}
if incred == 1 {
conn.Do("EXPIRE", key, interval)
}
if incred > int64(count) {
return ecode.LimitExceed
}
return nil
}
func rateKey(mid int64, path string, burst int) string {
return fmt.Sprintf(_prefixRate, mid, path, burst)
}
func totalKey(mid int64, path string, burst int) string {
return fmt.Sprintf(_prefixTotal, mid, path, burst)
}
func (s *Antispam) ServeHTTP(ctx *bm.Context) {
if err := s.Rate(ctx, s.conf.Second, s.conf.N); err != nil {
ctx.JSON(nil, ecode.LimitExceed)
ctx.Abort()
return
}
if err := s.Total(ctx, s.conf.Hour, s.conf.M); err != nil {
ctx.JSON(nil, ecode.LimitExceed)
ctx.Abort()
return
}
}
// Handler is antispam handle.
func (s *Antispam) Handler() bm.HandlerFunc {
return s.ServeHTTP
}

View File

@@ -0,0 +1,108 @@
package antispam
import (
"context"
"io/ioutil"
"net/http"
"strconv"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go-common/library/cache/redis"
"go-common/library/container/pool"
bm "go-common/library/net/http/blademaster"
xtime "go-common/library/time"
)
func TestAntiSpamHandler(t *testing.T) {
anti := New(
&Config{
On: true,
Second: 1,
N: 1,
Hour: 1,
M: 1,
Redis: &redis.Config{
Config: &pool.Config{
Active: 10,
Idle: 10,
IdleTimeout: xtime.Duration(time.Second * 60),
},
Name: "test",
Proto: "tcp",
Addr: "172.18.33.60:6889",
DialTimeout: xtime.Duration(time.Second),
ReadTimeout: xtime.Duration(time.Second),
WriteTimeout: xtime.Duration(time.Second),
},
},
)
engine := bm.New()
engine.UseFunc(func(c *bm.Context) {
mid, _ := strconv.ParseInt(c.Request.Form.Get("mid"), 10, 64)
c.Set("mid", mid)
c.Next()
})
engine.Use(anti.Handler())
engine.GET("/antispam", func(c *bm.Context) {
c.String(200, "pass")
})
go engine.Run(":18080")
time.Sleep(time.Millisecond * 50)
code, content, err := httpGet("http://127.0.0.1:18080/antispam?mid=11")
if err != nil {
t.Logf("http get failed, err:=%v", err)
t.FailNow()
}
if code != 200 || string(content) != "pass" {
t.Logf("request should pass by limiter, but blocked: %d, %v", code, content)
t.FailNow()
}
_, content, err = httpGet("http://127.0.0.1:18080/antispam?mid=11")
if err != nil {
t.Logf("http get failed, err:=%v", err)
t.FailNow()
}
if string(content) == "pass" {
t.Logf("request should block by limiter, but passed")
t.FailNow()
}
engine.Server().Shutdown(context.TODO())
}
func httpGet(url string) (code int, content []byte, err error) {
resp, err := http.Get(url)
if err != nil {
return
}
defer resp.Body.Close()
content, err = ioutil.ReadAll(resp.Body)
if err != nil {
return
}
code = resp.StatusCode
return
}
func TestConfigValidate(t *testing.T) {
var conf *Config
assert.Contains(t, conf.validate().Error(), "empty config")
conf = &Config{
Second: 0,
}
assert.Contains(t, conf.validate().Error(), "invalid Second")
conf = &Config{
Second: 1,
Hour: 0,
}
assert.Contains(t, conf.validate().Error(), "invalid Hour")
}

View File

@@ -0,0 +1,45 @@
package antispam_test
import (
"time"
"go-common/library/cache/redis"
"go-common/library/container/pool"
"go-common/library/net/http/blademaster"
"go-common/library/net/http/blademaster/middleware/antispam"
xtime "go-common/library/time"
)
// This example create a antispam middleware instance and attach to a blademaster engine,
// it will protect '/ping' API with specified policy.
// If anyone who requests this API more frequently than 1 req/second or 1 req/hour,
// a StatusServiceUnavailable error will be raised.
func Example() {
anti := antispam.New(&antispam.Config{
On: true,
Second: 1,
N: 1,
Hour: 1,
M: 1,
Redis: &redis.Config{
Config: &pool.Config{
Active: 10,
Idle: 10,
IdleTimeout: xtime.Duration(time.Second * 60),
},
Name: "test",
Proto: "tcp",
Addr: "172.18.33.60:6889",
DialTimeout: xtime.Duration(time.Second),
ReadTimeout: xtime.Duration(time.Second),
WriteTimeout: xtime.Duration(time.Second),
},
})
engine := blademaster.Default()
engine.Use(anti)
engine.GET("/ping", func(c *blademaster.Context) {
c.String(200, "%s", "pong")
})
engine.Run(":18080")
}

View File

@@ -0,0 +1,67 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_test",
"go_library",
)
go_test(
name = "go_default_test",
srcs = ["auth_test.go"],
embed = [":go_default_library"],
rundir = ".",
tags = ["automanaged"],
deps = [
"//library/ecode:go_default_library",
"//library/log:go_default_library",
"//library/net/http/blademaster:go_default_library",
"//library/net/metadata:go_default_library",
"//library/net/netutil/breaker:go_default_library",
"//library/net/rpc/warden:go_default_library",
"//library/time:go_default_library",
"//vendor/github.com/stretchr/testify/assert:go_default_library",
],
)
go_library(
name = "go_default_library",
srcs = ["auth.go"],
importpath = "go-common/library/net/http/blademaster/middleware/auth",
tags = ["automanaged"],
visibility = ["//visibility:public"],
deps = [
"//app/service/main/identify/api/grpc:go_default_library",
"//library/ecode:go_default_library",
"//library/net/http/blademaster:go_default_library",
"//library/net/metadata:go_default_library",
"//library/net/rpc/warden:go_default_library",
"//vendor/github.com/pkg/errors:go_default_library",
],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)
go_test(
name = "go_default_xtest",
srcs = ["example_test.go"],
tags = ["automanaged"],
deps = [
"//library/net/http/blademaster:go_default_library",
"//library/net/http/blademaster/middleware/auth:go_default_library",
"//library/net/metadata:go_default_library",
"//library/net/rpc/warden:go_default_library",
],
)

View File

@@ -0,0 +1,13 @@
#### library/net/http/blademaster/middleware/auth
##### Version 1.0.2
1. 将认证使用的方法作为公开方法
##### Version 1.0.1
1. 成功后将 mid 加入 metadata
##### Version 1.0.0
1. 完全使用 identify-service 来认证用户

View File

@@ -0,0 +1,12 @@
# Owner
maojian
zhoujiahui
# Author
maojian
zhoujiahui
# Reviewer
maojian
haoguanwei
wanghuan01

View File

@@ -0,0 +1,10 @@
# See the OWNERS docs at https://go.k8s.io/owners
approvers:
- maojian
- zhoujiahui
reviewers:
- haoguanwei
- maojian
- wanghuan01
- zhoujiahui

View File

@@ -0,0 +1,13 @@
#### library/net/http/blademaster/middleware/auth
##### 项目简介
blademaster 的 authorization middleware主要用于设置路由的认证策略
##### 编译环境
- **请只用 Golang v1.10.x 以上版本编译执行**
##### 依赖包
- No other dependency

View File

@@ -0,0 +1,177 @@
package auth
import (
idtv1 "go-common/app/service/main/identify/api/grpc"
"go-common/library/ecode"
bm "go-common/library/net/http/blademaster"
"go-common/library/net/metadata"
"go-common/library/net/rpc/warden"
"github.com/pkg/errors"
)
// Config is the identify config model.
type Config struct {
Identify *warden.ClientConfig
// csrf switch.
DisableCSRF bool
}
// Auth is the authorization middleware
type Auth struct {
idtv1.IdentifyClient
conf *Config
}
// authFunc will return mid and error by given context
type authFunc func(*bm.Context) (int64, error)
var _defaultConf = &Config{
Identify: nil,
DisableCSRF: false,
}
// New is used to create an authorization middleware
func New(conf *Config) *Auth {
if conf == nil {
conf = _defaultConf
}
identify, err := idtv1.NewClient(conf.Identify)
if err != nil {
panic(errors.WithMessage(err, "Failed to dial identify service"))
}
auth := &Auth{
IdentifyClient: identify,
conf: conf,
}
return auth
}
// User is used to mark path as access required.
// If `access_key` is exist in request form, it will using mobile access policy.
// Otherwise to web access policy.
func (a *Auth) User(ctx *bm.Context) {
req := ctx.Request
if req.Form.Get("access_key") == "" {
a.UserWeb(ctx)
return
}
a.UserMobile(ctx)
}
// UserWeb is used to mark path as web access required.
func (a *Auth) UserWeb(ctx *bm.Context) {
a.midAuth(ctx, a.AuthCookie)
}
// UserMobile is used to mark path as mobile access required.
func (a *Auth) UserMobile(ctx *bm.Context) {
a.midAuth(ctx, a.AuthToken)
}
// Guest is used to mark path as guest policy.
// If `access_key` is exist in request form, it will using mobile access policy.
// Otherwise to web access policy.
func (a *Auth) Guest(ctx *bm.Context) {
req := ctx.Request
if req.Form.Get("access_key") == "" {
a.GuestWeb(ctx)
return
}
a.GuestMobile(ctx)
}
// GuestWeb is used to mark path as web guest policy.
func (a *Auth) GuestWeb(ctx *bm.Context) {
a.guestAuth(ctx, a.AuthCookie)
}
// GuestMobile is used to mark path as mobile guest policy.
func (a *Auth) GuestMobile(ctx *bm.Context) {
a.guestAuth(ctx, a.AuthToken)
}
// AuthToken is used to authorize request by token
func (a *Auth) AuthToken(ctx *bm.Context) (int64, error) {
req := ctx.Request
key := req.Form.Get("access_key")
if key == "" {
return 0, ecode.NoLogin
}
buvid := req.Header.Get("buvid")
reply, err := a.GetTokenInfo(ctx, &idtv1.GetTokenInfoReq{Token: key, Buvid: buvid})
if err != nil {
return 0, err
}
if !reply.IsLogin {
return 0, ecode.NoLogin
}
return reply.Mid, nil
}
// AuthCookie is used to authorize request by cookie
func (a *Auth) AuthCookie(ctx *bm.Context) (int64, error) {
req := ctx.Request
ssDaCk, _ := req.Cookie("SESSDATA")
if ssDaCk == nil {
return 0, ecode.NoLogin
}
cookie := req.Header.Get("Cookie")
reply, err := a.GetCookieInfo(ctx, &idtv1.GetCookieInfoReq{Cookie: cookie})
if err != nil {
return 0, err
}
if !reply.IsLogin {
return 0, ecode.NoLogin
}
// check csrf
clientCsrf := req.FormValue("csrf")
if a.conf != nil && !a.conf.DisableCSRF && req.Method == "POST" {
if clientCsrf != reply.Csrf {
return 0, ecode.CsrfNotMatchErr
}
}
return reply.Mid, nil
}
func (a *Auth) midAuth(ctx *bm.Context, auth authFunc) {
mid, err := auth(ctx)
if err != nil {
ctx.JSON(nil, err)
ctx.Abort()
return
}
setMid(ctx, mid)
}
func (a *Auth) guestAuth(ctx *bm.Context, auth authFunc) {
mid, err := auth(ctx)
// no error happened and mid is valid
if err == nil && mid > 0 {
setMid(ctx, mid)
return
}
ec := ecode.Cause(err)
if ec.Equal(ecode.CsrfNotMatchErr) {
ctx.JSON(nil, ec)
ctx.Abort()
return
}
}
// set mid into context
// NOTE: This method is not thread safe.
func setMid(ctx *bm.Context, mid int64) {
ctx.Set("mid", mid)
if md, ok := metadata.FromContext(ctx); ok {
md[metadata.Mid] = mid
return
}
}

View File

@@ -0,0 +1,341 @@
package auth
import (
"bytes"
"context"
"fmt"
"mime/multipart"
"net/http"
"net/url"
"testing"
"time"
"go-common/library/ecode"
"go-common/library/log"
bm "go-common/library/net/http/blademaster"
"go-common/library/net/metadata"
"go-common/library/net/netutil/breaker"
"go-common/library/net/rpc/warden"
xtime "go-common/library/time"
"github.com/stretchr/testify/assert"
)
const (
_testUID = "2231365"
)
type Response struct {
Code int `json:"code"`
Data string `json:"data"`
}
func init() {
log.Init(&log.Config{
Stdout: true,
})
}
func client() *bm.Client {
return bm.NewClient(&bm.ClientConfig{
App: &bm.App{
Key: "53e2fa226f5ad348",
Secret: "3cf6bd1b0ff671021da5f424fea4b04a",
},
Dial: xtime.Duration(time.Second),
Timeout: xtime.Duration(time.Second),
KeepAlive: xtime.Duration(time.Second * 10),
Breaker: &breaker.Config{
Window: xtime.Duration(time.Second),
Sleep: xtime.Duration(time.Millisecond * 100),
Bucket: 10,
Ratio: 0.5,
Request: 100,
},
})
}
func create() *Auth {
return New(&Config{
Identify: &warden.ClientConfig{},
DisableCSRF: false,
})
}
func engine() *bm.Engine {
e := bm.NewServer(nil)
authn := create()
e.GET("/user", authn.User, func(ctx *bm.Context) {
mid, _ := ctx.Get("mid")
ctx.JSON(fmt.Sprintf("%d", mid), nil)
})
e.GET("/metadata/user", authn.User, func(ctx *bm.Context) {
mid := metadata.Value(ctx, metadata.Mid)
ctx.JSON(fmt.Sprintf("%d", mid.(int64)), nil)
})
e.GET("/mobile", authn.UserMobile, func(ctx *bm.Context) {
mid, _ := ctx.Get("mid")
ctx.JSON(fmt.Sprintf("%d", mid), nil)
})
e.GET("/metadata/mobile", authn.UserMobile, func(ctx *bm.Context) {
mid := metadata.Value(ctx, metadata.Mid)
ctx.JSON(fmt.Sprintf("%d", mid.(int64)), nil)
})
e.GET("/web", authn.UserWeb, func(ctx *bm.Context) {
mid, _ := ctx.Get("mid")
ctx.JSON(fmt.Sprintf("%d", mid), nil)
})
e.GET("/guest", authn.Guest, func(ctx *bm.Context) {
var (
mid int64
)
if _mid, ok := ctx.Get("mid"); ok {
mid, _ = _mid.(int64)
}
ctx.JSON(fmt.Sprintf("%d", mid), nil)
})
e.GET("/guest/web", authn.GuestWeb, func(ctx *bm.Context) {
var (
mid int64
)
if _mid, ok := ctx.Get("mid"); ok {
mid, _ = _mid.(int64)
}
ctx.JSON(fmt.Sprintf("%d", mid), nil)
})
e.GET("/guest/mobile", authn.GuestMobile, func(ctx *bm.Context) {
var (
mid int64
)
if _mid, ok := ctx.Get("mid"); ok {
mid, _ = _mid.(int64)
}
ctx.JSON(fmt.Sprintf("%d", mid), nil)
})
e.POST("/guest/csrf", authn.Guest, func(ctx *bm.Context) {
var (
mid int64
)
if _mid, ok := ctx.Get("mid"); ok {
mid, _ = _mid.(int64)
}
ctx.JSON(fmt.Sprintf("%d", mid), nil)
})
return e
}
func TestFromNilConfig(t *testing.T) {
New(nil)
}
func TestIdentifyHandler(t *testing.T) {
e := engine()
go e.Run(":18080")
time.Sleep(time.Second)
// test cases
testWebUser(t, "/user")
testWebUser(t, "/metadata/user")
testWebUser(t, "/web")
testWebUser(t, "/guest")
testWebUser(t, "/guest/web")
testWebUserFailed(t, "/user")
testWebUserFailed(t, "/web")
testMobileUser(t, "/user")
testMobileUser(t, "/mobile")
testMobileUser(t, "/metadata/mobile")
testMobileUser(t, "/guest")
testMobileUser(t, "/guest/mobile")
testMobileUserFailed(t, "/user")
testMobileUserFailed(t, "/mobile")
testGuest(t, "/guest")
testGuestCSRF(t, "/guest/csrf")
testGuestCSRFFailed(t, "/guest/csrf")
testMultipartCSRF(t, "/guest/csrf")
if err := e.Server().Shutdown(context.TODO()); err != nil {
t.Logf("Failed to shutdown bm engine: %v", err)
}
}
func testWebUser(t *testing.T, path string) {
res := Response{}
query := url.Values{}
cli := client()
req, err := cli.NewRequest(http.MethodGet, "http://127.0.0.1:18080/"+path, "", query)
assert.NoError(t, err)
req.AddCookie(&http.Cookie{
Name: "DedeUserID",
Value: _testUID,
})
req.AddCookie(&http.Cookie{
Name: "DedeUserID__ckMd5",
Value: "36976f7a5cb6e4a6",
})
req.AddCookie(&http.Cookie{
Name: "SESSDATA",
Value: "7bf20cf0%2C1540627371%2C8ec39f0c",
})
err = cli.Do(context.TODO(), req, &res)
assert.NoError(t, err)
assert.Equal(t, 0, res.Code)
assert.Equal(t, _testUID, res.Data)
}
func testMobileUser(t *testing.T, path string) {
res := Response{}
query := url.Values{}
query.Set("access_key", "cdbd166be6673a5a4f6fbcdd88569edf")
cli := client()
req, err := cli.NewRequest(http.MethodGet, "http://127.0.0.1:18080"+path, "", query)
assert.NoError(t, err)
err = cli.Do(context.TODO(), req, &res)
assert.NoError(t, err)
assert.Equal(t, 0, res.Code)
assert.Equal(t, _testUID, res.Data)
}
func testWebUserFailed(t *testing.T, path string) {
res := Response{}
query := url.Values{}
cli := client()
req, err := cli.NewRequest(http.MethodGet, "http://127.0.0.1:18080/"+path, "", query)
assert.NoError(t, err)
req.AddCookie(&http.Cookie{
Name: "DedeUserID",
Value: _testUID,
})
req.AddCookie(&http.Cookie{
Name: "DedeUserID__ckMd5",
Value: "53c4b106fb4462f1",
})
req.AddCookie(&http.Cookie{
Name: "SESSDATA",
Value: "6eeda532%2C1515837495%2C5a6baa4e",
})
err = cli.Do(context.TODO(), req, &res)
assert.NoError(t, err)
assert.Equal(t, ecode.NoLogin.Code(), res.Code)
assert.Empty(t, res.Data)
}
func testMobileUserFailed(t *testing.T, path string) {
res := Response{}
query := url.Values{}
query.Set("access_key", "5dce488c2ff8d62d7b131da40ae18729")
cli := client()
req, err := cli.NewRequest(http.MethodGet, "http://127.0.0.1:18080"+path, "", query)
assert.NoError(t, err)
err = cli.Do(context.TODO(), req, &res)
assert.NoError(t, err)
assert.Equal(t, ecode.NoLogin.Code(), res.Code)
assert.Empty(t, res.Data)
}
func testGuest(t *testing.T, path string) {
res := Response{}
cli := client()
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:18080"+path, nil)
assert.NoError(t, err)
err = cli.Do(context.TODO(), req, &res)
assert.NoError(t, err)
assert.Equal(t, 0, res.Code)
assert.Equal(t, "0", res.Data)
}
func testGuestCSRF(t *testing.T, path string) {
res := Response{}
param := url.Values{}
param.Set("csrf", "c1524bbf3aa5a1996ff7b1f29a09e796")
cli := client()
req, err := cli.NewRequest(http.MethodPost, "http://127.0.0.1:18080"+path, "", param)
assert.NoError(t, err)
req.AddCookie(&http.Cookie{
Name: "DedeUserID",
Value: _testUID,
})
req.AddCookie(&http.Cookie{
Name: "DedeUserID__ckMd5",
Value: "36976f7a5cb6e4a6",
})
req.AddCookie(&http.Cookie{
Name: "SESSDATA",
Value: "7bf20cf0%2C1540627371%2C8ec39f0c",
})
err = cli.Do(context.TODO(), req, &res)
assert.NoError(t, err)
assert.Equal(t, 0, res.Code)
assert.Equal(t, _testUID, res.Data)
}
func testGuestCSRFFailed(t *testing.T, path string) {
res := Response{}
param := url.Values{}
param.Set("csrf", "invalid-csrf-token")
cli := client()
req, err := cli.NewRequest(http.MethodPost, "http://127.0.0.1:18080"+path, "", param)
assert.NoError(t, err)
req.AddCookie(&http.Cookie{
Name: "DedeUserID",
Value: _testUID,
})
req.AddCookie(&http.Cookie{
Name: "DedeUserID__ckMd5",
Value: "36976f7a5cb6e4a6",
})
req.AddCookie(&http.Cookie{
Name: "SESSDATA",
Value: "7bf20cf0%2C1540627371%2C8ec39f0c",
})
err = cli.Do(context.TODO(), req, &res)
assert.NoError(t, err)
assert.Equal(t, ecode.CsrfNotMatchErr.Code(), res.Code)
assert.Empty(t, res.Data)
}
func testMultipartCSRF(t *testing.T, path string) {
res := Response{}
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
writer.WriteField("csrf", "c1524bbf3aa5a1996ff7b1f29a09e796")
writer.Close()
req, err := http.NewRequest("POST", "http://127.0.0.1:18080"+path, body)
assert.NoError(t, err)
req.Header.Set("Content-Type", writer.FormDataContentType())
cli := client()
req.AddCookie(&http.Cookie{
Name: "DedeUserID",
Value: _testUID,
})
req.AddCookie(&http.Cookie{
Name: "DedeUserID__ckMd5",
Value: "36976f7a5cb6e4a6",
})
req.AddCookie(&http.Cookie{
Name: "SESSDATA",
Value: "7bf20cf0%2C1540627371%2C8ec39f0c",
})
err = cli.Do(context.TODO(), req, &res)
assert.NoError(t, err)
assert.Equal(t, 0, res.Code)
assert.Equal(t, _testUID, res.Data)
}

View File

@@ -0,0 +1,45 @@
package auth_test
import (
"fmt"
bm "go-common/library/net/http/blademaster"
"go-common/library/net/http/blademaster/middleware/auth"
"go-common/library/net/metadata"
"go-common/library/net/rpc/warden"
)
// This example create a identify middleware instance and attach to several path,
// it will validate request by specified policy and put extra information into context. e.g., `mid`.
// It provides additional handler functions to provide the identification for your business handler.
func Example() {
authn := auth.New(&auth.Config{
Identify: &warden.ClientConfig{},
DisableCSRF: false,
})
e := bm.DefaultServer(nil)
// mark `/user` path as User policy
e.GET("/user", authn.User, func(ctx *bm.Context) {
mid := metadata.Int64(ctx, metadata.Mid)
ctx.JSON(fmt.Sprintf("%d", mid), nil)
})
// mark `/mobile` path as UserMobile policy
e.GET("/mobile", authn.UserMobile, func(ctx *bm.Context) {
mid := metadata.Int64(ctx, metadata.Mid)
ctx.JSON(fmt.Sprintf("%d", mid), nil)
})
// mark `/web` path as UserWeb policy
e.GET("/web", authn.UserWeb, func(ctx *bm.Context) {
mid := metadata.Int64(ctx, metadata.Mid)
ctx.JSON(fmt.Sprintf("%d", mid), nil)
})
// mark `/guest` path as Guest policy
e.GET("/guest", authn.Guest, func(ctx *bm.Context) {
mid := metadata.Int64(ctx, metadata.Mid)
ctx.JSON(fmt.Sprintf("%d", mid), nil)
})
e.Run(":18080")
}

View File

@@ -0,0 +1,101 @@
load(
"@io_bazel_rules_go//proto:def.bzl",
"go_proto_library",
)
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_test",
"go_library",
)
proto_library(
name = "cache_proto",
srcs = ["page.proto"],
tags = ["automanaged"],
deps = ["@gogo_special_proto//github.com/gogo/protobuf/gogoproto"],
)
go_proto_library(
name = "cache_go_proto",
compilers = ["@io_bazel_rules_go//proto:gogofast_proto"],
importpath = "go-common/library/net/http/blademaster/middleware/cache",
proto = ":cache_proto",
tags = ["automanaged"],
deps = ["@com_github_gogo_protobuf//gogoproto:go_default_library"],
)
go_test(
name = "go_default_test",
srcs = ["cache_test.go"],
embed = [":go_default_library"],
rundir = ".",
tags = ["automanaged"],
deps = [
"//library/cache/memcache:go_default_library",
"//library/container/pool:go_default_library",
"//library/ecode:go_default_library",
"//library/log:go_default_library",
"//library/net/http/blademaster:go_default_library",
"//library/net/http/blademaster/middleware/cache/store:go_default_library",
"//library/time:go_default_library",
"//vendor/github.com/stretchr/testify/assert:go_default_library",
],
)
go_library(
name = "go_default_library",
srcs = [
"cache.go",
"control.go",
"degrade.go",
"page.go",
],
embed = [":cache_go_proto"],
importpath = "go-common/library/net/http/blademaster/middleware/cache",
tags = ["automanaged"],
visibility = ["//visibility:public"],
deps = [
"//library/ecode:go_default_library",
"//library/log:go_default_library",
"//library/net/http/blademaster:go_default_library",
"//library/net/http/blademaster/middleware/cache/store:go_default_library",
"@com_github_gogo_protobuf//gogoproto:go_default_library",
"@com_github_gogo_protobuf//proto:go_default_library",
],
)
go_test(
name = "go_default_xtest",
srcs = ["example_test.go"],
tags = ["automanaged"],
deps = [
"//library/cache/memcache:go_default_library",
"//library/container/pool:go_default_library",
"//library/ecode:go_default_library",
"//library/net/http/blademaster:go_default_library",
"//library/net/http/blademaster/middleware/cache:go_default_library",
"//library/net/http/blademaster/middleware/cache/store:go_default_library",
"//library/time:go_default_library",
"//vendor/github.com/pkg/errors:go_default_library",
],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [
":package-srcs",
"//library/net/http/blademaster/middleware/cache/store:all-srcs",
],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)

View File

@@ -0,0 +1,10 @@
### business/blademaster/cache
##### Version 1.0.1
1. 添加 Control 策略(目前仅通过 Expires 和 Cache-Control 实现客户端缓存)
##### Version 1.0.0
1. 完成基本功能与测试
2. 完成 Degrade 与 PageCache 逻辑

View File

@@ -0,0 +1,5 @@
# Author
zhoujiahui
# Reviewer
maojian

View File

@@ -0,0 +1,7 @@
# See the OWNERS docs at https://go.k8s.io/owners
approvers:
- zhoujiahui
reviewers:
- maojian
- zhoujiahui

View File

@@ -0,0 +1,13 @@
#### business/blademaster/cache
##### 项目简介
blademaster 的通用 cache 模块,一般直接用于缓存返回的 response
##### 编译环境
- **请只用 Golang v1.8.x 以上版本编译执行**
##### 依赖包
- No other dependency

View File

@@ -0,0 +1,38 @@
package cache
import (
bm "go-common/library/net/http/blademaster"
"go-common/library/net/http/blademaster/middleware/cache/store"
)
// Cache is the abstract struct for any cache impl
type Cache struct {
store store.Store
}
// Filter is used to check is cache required for every request
type Filter func(*bm.Context) bool
// Policy is used to abstract different cache policy
type Policy interface {
Key(*bm.Context) string
Handler(store.Store) bm.HandlerFunc
}
// New will create a new Cache struct
func New(store store.Store) *Cache {
c := &Cache{
store: store,
}
return c
}
// Cache is used to mark path as customized cache policy
func (c *Cache) Cache(policy Policy, filter Filter) bm.HandlerFunc {
return func(ctx *bm.Context) {
if filter != nil && !filter(ctx) {
return
}
policy.Handler(c.store)(ctx)
}
}

View File

@@ -0,0 +1,353 @@
package cache
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"os"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"go-common/library/cache/memcache"
"go-common/library/container/pool"
"go-common/library/ecode"
"go-common/library/log"
bm "go-common/library/net/http/blademaster"
"go-common/library/net/http/blademaster/middleware/cache/store"
xtime "go-common/library/time"
"github.com/stretchr/testify/assert"
)
const (
SockAddr = "127.0.0.1:18080"
McSockAddr = "172.16.33.54:11211"
)
func uri(base, path string) string {
return fmt.Sprintf("%s://%s%s", "http", base, path)
}
func init() {
log.Init(nil)
}
func newMemcache() (*Cache, func()) {
s := store.NewMemcache(&memcache.Config{
Config: &pool.Config{
Active: 10,
Idle: 2,
IdleTimeout: xtime.Duration(time.Second),
},
Name: "test",
Proto: "tcp",
Addr: McSockAddr,
DialTimeout: xtime.Duration(time.Second),
ReadTimeout: xtime.Duration(time.Second),
WriteTimeout: xtime.Duration(time.Second),
})
return New(s), func() {}
}
func newFile() (*Cache, func()) {
path, err := ioutil.TempDir("", "cache-test")
if err != nil {
panic("Failed to create cache directory")
}
s := store.NewFile(&store.FileConfig{
RootDir: path,
})
remove := func() {
os.RemoveAll(path)
}
return New(s), remove
}
func TestPage(t *testing.T) {
memcache, remove1 := newMemcache()
filestore, remove2 := newFile()
defer func() {
remove1()
remove2()
}()
t.Run("Memcache Store", pageCase(memcache, true))
t.Run("File Store", pageCase(filestore, false))
}
func TestControl(t *testing.T) {
memcache, remove1 := newMemcache()
filestore, remove2 := newFile()
defer func() {
remove1()
remove2()
}()
t.Run("Memcache Store", controlCase(memcache, true))
t.Run("File Store", controlCase(filestore, false))
}
func TestPageCacheMultiWrite(t *testing.T) {
memcache, remove1 := newMemcache()
filestore, remove2 := newFile()
defer func() {
remove1()
remove2()
}()
t.Run("Memcache Store", pageMultiWriteCase(memcache))
t.Run("File Store", pageMultiWriteCase(filestore))
}
func TestDegrade(t *testing.T) {
memcache, remove1 := newMemcache()
filestore, remove2 := newFile()
defer func() {
remove1()
remove2()
}()
t.Run("Memcache Store", degradeCase(memcache))
t.Run("File Store", degradeCase(filestore))
}
func pageCase(cache *Cache, testExpire bool) func(t *testing.T) {
return func(t *testing.T) {
expire := int32(3)
pc := NewPage(expire)
engine := bm.Default()
engine.GET("/page-cache", cache.Cache(pc, nil), func(ctx *bm.Context) {
ctx.Writer.Header().Set("X-Hello", "World")
ctx.String(203, "%s\n", time.Now().String())
})
go engine.Run(SockAddr)
defer func() {
engine.Server().Shutdown(context.Background())
}()
time.Sleep(time.Second)
code1, content1, headers1, err1 := httpGet(uri(SockAddr, "/page-cache"))
code2, content2, headers2, err2 := httpGet(uri(SockAddr, "/page-cache"))
assert.Nil(t, err1)
assert.Nil(t, err2)
assert.Equal(t, code1, 203)
assert.Equal(t, code2, 203)
assert.NotNil(t, content1)
assert.NotNil(t, content2)
assert.Equal(t, headers1["X-Hello"], []string{"World"})
assert.Equal(t, headers2["X-Hello"], []string{"World"})
assert.Equal(t, string(content1), string(content2))
if !testExpire {
return
}
// test if the last caching is expired
t.Logf("Waiting %d seconds for caching expire test", expire+1)
time.Sleep(time.Second * time.Duration(expire+1))
_, content3, _, err3 := httpGet(uri(SockAddr, "/page-cache"))
_, content4, _, err4 := httpGet(uri(SockAddr, "/page-cache"))
assert.Nil(t, err3)
assert.Nil(t, err4)
assert.NotNil(t, content1)
assert.NotNil(t, content2)
assert.NotEqual(t, string(content1), string(content3))
assert.Equal(t, string(content3), string(content4))
}
}
func pageMultiWriteCase(cache *Cache) func(t *testing.T) {
return func(t *testing.T) {
chunks := []string{
"Hello",
"World",
"Hello",
"World",
"Hello",
"World",
"Hello",
"World",
}
pc := NewPage(3)
engine := bm.Default()
engine.GET("/page-cache-write", cache.Cache(pc, nil), func(ctx *bm.Context) {
ctx.Writer.Header().Set("X-Hello", "World")
ctx.Writer.WriteHeader(203)
for _, chunk := range chunks {
ctx.Writer.Write([]byte(chunk))
}
})
go engine.Run(SockAddr)
defer func() {
engine.Server().Shutdown(context.Background())
}()
time.Sleep(time.Second)
code1, content1, headers1, err1 := httpGet(uri(SockAddr, "/page-cache-write"))
code2, content2, headers2, err2 := httpGet(uri(SockAddr, "/page-cache-write"))
assert.Nil(t, err1)
assert.Nil(t, err2)
assert.Equal(t, code1, 203)
assert.Equal(t, code2, 203)
assert.NotNil(t, content1)
assert.NotNil(t, content2)
assert.Equal(t, headers1["X-Hello"], []string{"World"})
assert.Equal(t, headers2["X-Hello"], []string{"World"})
assert.Equal(t, strings.Join(chunks, ""), string(content1))
assert.Equal(t, strings.Join(chunks, ""), string(content2))
assert.Equal(t, string(content1), string(content2))
}
}
func degradeCase(cache *Cache) func(t *testing.T) {
return func(t *testing.T) {
wg := sync.WaitGroup{}
i := int32(0)
degrade := NewDegrader(10)
engine := bm.Default()
engine.GET("/scheduled/error", cache.Cache(degrade.Args("name", "age"), nil), func(c *bm.Context) {
code := atomic.AddInt32(&i, 1)
if code == 5 {
c.JSON("succeed", nil)
return
}
if code%2 == 0 {
c.JSON("", ecode.Degrade)
return
}
c.JSON(fmt.Sprintf("Code: %d", code), ecode.Int(int(code)))
})
wg.Add(1)
go func() {
engine.Run(":18080")
wg.Done()
}()
defer func() {
engine.Server().Shutdown(context.TODO())
wg.Wait()
}()
time.Sleep(time.Second)
for index := 1; index < 10; index++ {
_, content, _, _ := httpGet(uri(SockAddr, "/scheduled/error?name=degrader&age=26"))
t.Log(index, string(content))
var res struct {
Data string `json:"data"`
}
err := json.Unmarshal(content, &res)
assert.Nil(t, err)
if index == 5 {
// ensure response is write to cache
time.Sleep(time.Second)
}
if index > 5 && index%2 == 0 {
if res.Data != "succeed" {
t.Fatalf("Failed to degrade at index: %d", index)
} else {
t.Logf("This request is degraded at index: %d", index)
}
}
}
}
}
func controlCase(cache *Cache, testExpire bool) func(t *testing.T) {
return func(t *testing.T) {
wg := sync.WaitGroup{}
i := int32(0)
expire := int32(30)
control := NewControl(expire)
filter := func(ctx *bm.Context) bool {
if ctx.Request.Form.Get("cache") == "false" {
return false
}
return true
}
engine := bm.Default()
engine.GET("/large/response", cache.Cache(control, filter), func(c *bm.Context) {
c.JSON(map[string]interface{}{
"index": atomic.AddInt32(&i, 1),
"Hello0": "World",
"Hello1": "World",
"Hello2": "World",
"Hello3": "World",
"Hello4": "World",
"Hello5": "World",
"Hello6": "World",
"Hello7": "World",
"Hello8": "World",
}, nil)
})
engine.GET("/large/response/error", cache.Cache(control, filter), func(c *bm.Context) {
c.JSON(nil, ecode.RequestErr)
})
wg.Add(1)
go func() {
engine.Run(":18080")
wg.Done()
}()
defer func() {
engine.Server().Shutdown(context.TODO())
wg.Wait()
}()
time.Sleep(time.Second)
code, content, headers, err := httpGet(uri(SockAddr, "/large/response?name=hello&age=1"))
assert.NoError(t, err)
assert.Equal(t, 200, code)
assert.NotEmpty(t, content)
assert.Equal(t, "max-age=30", headers.Get("Cache-Control"))
exp, err := http.ParseTime(headers.Get("Expires"))
assert.NoError(t, err)
assert.InDelta(t, 30, exp.Unix()-time.Now().Unix(), 5)
code, content, headers, err = httpGet(uri(SockAddr, "/large/response/error?name=hello&age=1&cache=false"))
assert.NoError(t, err)
assert.Equal(t, 200, code)
assert.NotEmpty(t, content)
assert.Empty(t, headers.Get("Expires"))
assert.Empty(t, headers.Get("Cache-Control"))
code, content, headers, err = httpGet(uri(SockAddr, "/large/response/error?name=hello&age=1"))
assert.NoError(t, err)
assert.Equal(t, 200, code)
assert.NotEmpty(t, content)
assert.Empty(t, headers.Get("Expires"))
assert.Empty(t, headers.Get("Cache-Control"))
}
}
func httpGet(url string) (code int, content []byte, headers http.Header, err error) {
resp, err := http.Get(url)
if err != nil {
return
}
defer resp.Body.Close()
if content, err = ioutil.ReadAll(resp.Body); err != nil {
return
}
code = resp.StatusCode
headers = resp.Header
return
}

View File

@@ -0,0 +1,74 @@
package cache
import (
fmt "fmt"
"net/http"
"sync"
"time"
bm "go-common/library/net/http/blademaster"
"go-common/library/net/http/blademaster/middleware/cache/store"
)
const (
_maxMaxAge = 60 * 5 // 5 minutes
)
// Control is used to work as client side cache orchestrator
type Control struct {
MaxAge int32
pool sync.Pool
}
type controlWriter struct {
*Control
ctx *bm.Context
response http.ResponseWriter
}
var _ http.ResponseWriter = &controlWriter{}
// NewControl will create a new control cache struct
func NewControl(maxAge int32) *Control {
if maxAge > _maxMaxAge {
panic("MaxAge should be less than 300 seconds")
}
ctl := &Control{
MaxAge: maxAge,
}
ctl.pool.New = func() interface{} {
return &controlWriter{}
}
return ctl
}
// Key method is not needed in this situation
func (ctl *Control) Key(ctx *bm.Context) string { return "" }
// Handler is used to execute cache service
func (ctl *Control) Handler(_ store.Store) bm.HandlerFunc {
return func(ctx *bm.Context) {
writer := ctl.pool.Get().(*controlWriter)
writer.Control = ctl
writer.ctx = ctx
writer.response = ctx.Writer
ctx.Writer = writer
ctx.Next()
ctl.pool.Put(writer)
}
}
func (w *controlWriter) Header() http.Header { return w.response.Header() }
func (w *controlWriter) Write(data []byte) (size int, err error) { return w.response.Write(data) }
func (w *controlWriter) WriteHeader(code int) {
// do not inject header if this is an error response
if w.ctx.Error == nil {
headers := w.Header()
headers.Set("Expires", time.Now().UTC().Add(time.Duration(w.MaxAge)*time.Second).Format(http.TimeFormat))
headers.Set("Cache-Control", fmt.Sprintf("max-age=%d", w.MaxAge))
}
w.response.WriteHeader(code)
}

View File

@@ -0,0 +1,219 @@
package cache
import (
"context"
"crypto/md5"
"fmt"
"net/http"
"strings"
"sync"
"sync/atomic"
"time"
"go-common/library/ecode"
"go-common/library/log"
bm "go-common/library/net/http/blademaster"
"go-common/library/net/http/blademaster/middleware/cache/store"
)
const (
_degradeInterval = 60 * 10
_degradePrefix = "bm.degrade"
)
var (
_degradeBytes = []byte(fmt.Sprintf("{\"code\":%d, \"message\":\"\"}", ecode.Degrade))
)
// Degrader is the common degrader instance.
type Degrader struct {
lock sync.RWMutex
urls map[string]*state
expire int32
ch chan *result
pool sync.Pool // degradeWriter pool
}
// argsDegrader means the degrade will happened by args policy
type argsDegrader struct {
*Degrader
args []string
}
type degradeWriter struct {
*Degrader
ctx *bm.Context
response http.ResponseWriter
store store.Store
key string
state *state
}
type state struct {
// FIXME(zhoujiahui): using transient map to avoid potential memory leak?
// record last cached time
sync.RWMutex
gens map[string]*int64
}
type result struct {
key string
value []byte
store store.Store
}
var _ http.ResponseWriter = &degradeWriter{}
var _ Policy = &argsDegrader{}
// NewDegrader will create a new degrade struct
func NewDegrader(expire int32) (d *Degrader) {
d = &Degrader{
urls: make(map[string]*state),
ch: make(chan *result, 1024),
expire: expire,
}
d.pool.New = func() interface{} {
return &degradeWriter{
Degrader: d,
}
}
go d.degradeproc()
return
}
func (d *Degrader) degradeproc() {
for {
r := <-d.ch
if err := r.store.Set(context.Background(), r.key, r.value, d.expire); err != nil {
log.Error("store write key(%s) error(%v)", r.key, err)
}
}
}
// Args means this path will be degrade by specified args
func (d *Degrader) Args(args ...string) Policy {
return &argsDegrader{
Degrader: d,
args: args,
}
}
func (d *Degrader) state(path string) *state {
d.lock.RLock()
s, ok := d.urls[path]
d.lock.RUnlock()
if !ok {
s = &state{
gens: make(map[string]*int64),
}
d.lock.Lock()
d.urls[path] = s
d.lock.Unlock()
}
return s
}
// Key is used to identify response cache key in most key-value store
func (ad *argsDegrader) Key(ctx *bm.Context) string {
req := ctx.Request
path := req.URL.Path
params := req.Form
vs := make([]string, 0, len(ad.args))
for _, arg := range ad.args {
vs = append(vs, params.Get(arg))
}
return fmt.Sprintf("%s:%s_%x", _degradePrefix, strings.Replace(path, "/", "_", -1), md5.Sum([]byte(strings.Join(vs, "-"))))
}
// Handler is used to execute degrade service
func (ad *argsDegrader) Handler(store store.Store) bm.HandlerFunc {
return func(ctx *bm.Context) {
req := ctx.Request
path := req.URL.Path
writer := ad.pool.Get().(*degradeWriter)
writer.response = ctx.Writer
writer.ctx = ctx
writer.store = store
writer.state = ad.state(path)
writer.key = ad.Key(ctx)
ctx.Writer = writer // replace to degrade writer
ctx.Next()
ad.pool.Put(writer)
}
}
func (w *degradeWriter) Header() http.Header { return w.response.Header() }
func (w *degradeWriter) WriteHeader(code int) { w.response.WriteHeader(code) }
func (w *degradeWriter) Write(data []byte) (size int, err error) {
e := w.ctx.Error
// if an degrade error code is raised from upstream,
// degrade this request directly
if e != nil {
if ec := ecode.Cause(e); ec.Code() == ecode.Degrade.Code() {
return w.write()
}
}
// write origin response
if size, err = w.response.Write(data); err != nil {
return
}
// error raised, this is a unsuccessful response
if e != nil {
return
}
// is required to cache
if !w.state.required(w.key) {
return
}
// async cache succeeded response for further degradation
select {
case w.ch <- &result{key: w.key, value: data, store: w.store}:
default:
}
return
}
func (w *degradeWriter) write() (int, error) {
data, err := w.store.Get(w.ctx, w.key)
if err != nil || len(data) == 0 {
// FIXME(zhoujiahui): The default response data should be respect to render type or content-type header
data = _degradeBytes
}
return w.response.Write(data)
}
// check is required to cache response
// it depends on last cache time and _degradeInterval
func (st *state) required(key string) bool {
now := time.Now().Unix()
st.RLock()
pLast, ok := st.gens[key]
st.RUnlock()
if !ok {
st.Lock()
pLast = new(int64)
st.gens[key] = pLast
st.Unlock()
}
last := atomic.LoadInt64(pLast)
if now-last < _degradeInterval {
return false
}
return atomic.CompareAndSwapInt64(pLast, last, now)
}

View File

@@ -0,0 +1,77 @@
package cache_test
import (
"time"
"go-common/library/cache/memcache"
"go-common/library/container/pool"
"go-common/library/ecode"
"go-common/library/net/http/blademaster"
"go-common/library/net/http/blademaster/middleware/cache"
"go-common/library/net/http/blademaster/middleware/cache/store"
xtime "go-common/library/time"
"github.com/pkg/errors"
)
// This example create a cache middleware instance and two cache policy,
// then attach them to the specified path.
//
// The `PageCache` policy will attempt to cache the whole response by URI.
// It usually used to cache the common response.
//
// The `Degrader` policy usually used to prevent the API totaly unavailable if any disaster is happen.
// A succeeded response will be cached per 600s.
// The cache key is generated by specified args and its values.
// You can using file or memcache as cache backend for degradation currently.
//
// The `Cache` policy is used to work with multilevel HTTP caching architecture.
// This will cause client side response caching.
// We only support weak validator with `ETag` header currently.
func Example() {
mc := store.NewMemcache(&memcache.Config{
Config: &pool.Config{
Active: 10,
Idle: 2,
IdleTimeout: xtime.Duration(time.Second),
},
Name: "test",
Proto: "tcp",
Addr: "172.16.33.54:11211",
DialTimeout: xtime.Duration(time.Second),
ReadTimeout: xtime.Duration(time.Second),
WriteTimeout: xtime.Duration(time.Second),
})
ca := cache.New(mc)
deg := cache.NewDegrader(10)
pc := cache.NewPage(10)
ctl := cache.NewControl(10)
filter := func(ctx *blademaster.Context) bool {
if ctx.Request.Form.Get("cache") == "false" {
return false
}
return true
}
engine := blademaster.Default()
engine.GET("/users/profile", ca.Cache(deg.Args("name", "age"), nil), func(c *blademaster.Context) {
values := c.Request.URL.Query()
name := values.Get("name")
age := values.Get("age")
err := errors.New("error from others") // error from other call
if err != nil {
// mark this response should be degraded
c.JSON(nil, ecode.Degrade)
return
}
c.JSON(map[string]string{"name": name, "age": age}, nil)
})
engine.GET("/users/index", ca.Cache(pc, nil), func(c *blademaster.Context) {
c.String(200, "%s", "Title: User")
})
engine.GET("/users/list", ca.Cache(ctl, filter), func(c *blademaster.Context) {
c.JSON([]string{"user1", "user2", "user3"}, nil)
})
engine.Run(":18080")
}

View File

@@ -0,0 +1,171 @@
package cache
import (
"bytes"
"crypto/sha1"
"io"
"net/http"
"net/url"
"sync"
"go-common/library/log"
bm "go-common/library/net/http/blademaster"
"go-common/library/net/http/blademaster/middleware/cache/store"
proto "github.com/gogo/protobuf/proto"
)
// consts for blademaster cache
const (
_pagePrefix = "bm.page"
)
// Page is used to cache common response
type Page struct {
Expire int32
pool sync.Pool
}
type cachedWriter struct {
ctx *bm.Context
response http.ResponseWriter
store store.Store
status int
expire int32
key string
}
var _ http.ResponseWriter = &cachedWriter{}
// NewPage will create a new page cache struct
func NewPage(expire int32) *Page {
pc := &Page{
Expire: expire,
}
pc.pool.New = func() interface{} {
return &cachedWriter{}
}
return pc
}
// Key is used to identify response cache key in most key-value store
func (p *Page) Key(ctx *bm.Context) string {
url := ctx.Request.URL
key := urlEscape(_pagePrefix, url.RequestURI())
return key
}
// Handler is used to execute cache service
func (p *Page) Handler(store store.Store) bm.HandlerFunc {
return func(ctx *bm.Context) {
var (
resp *ResponseCache
cached []byte
err error
)
key := p.Key(ctx)
cached, err = store.Get(ctx, key)
// if we did got the previous cache,
// try to unmarshal it
if err == nil && len(cached) > 0 {
resp = new(ResponseCache)
err = proto.Unmarshal(cached, resp)
}
// if we failed to fetch the cache or failed to parse cached data,
// then consider try to cache this response
if err != nil || resp == nil {
writer := p.pool.Get().(*cachedWriter)
writer.ctx = ctx
writer.response = ctx.Writer
writer.key = key
writer.expire = p.Expire
writer.store = store
ctx.Writer = writer
ctx.Next()
p.pool.Put(writer)
return
}
// write cached response
headers := ctx.Writer.Header()
for key, value := range resp.Header {
headers[key] = value.Value
}
ctx.Writer.WriteHeader(int(resp.Status))
ctx.Writer.Write(resp.Data)
ctx.Abort()
}
}
func (w *cachedWriter) Header() http.Header {
return w.response.Header()
}
func (w *cachedWriter) WriteHeader(code int) {
w.status = int(code)
w.response.WriteHeader(code)
}
func (w *cachedWriter) Write(data []byte) (size int, err error) {
var (
origin []byte
pdata []byte
)
if size, err = w.response.Write(data); err != nil {
return
}
store := w.store
origin, err = store.Get(w.ctx, w.key)
resp := new(ResponseCache)
if err == nil || len(origin) > 0 {
err1 := proto.Unmarshal(origin, resp)
if err1 == nil {
data = append(resp.Data, data...)
}
}
resp.Status = int32(w.status)
resp.Header = headerValues(w.Header())
resp.Data = data
if pdata, err = proto.Marshal(resp); err != nil {
// cannot happen
log.Error("Failed to marshal response to protobuf: %v", err)
return
}
if err = store.Set(w.ctx, w.key, pdata, w.expire); err != nil {
log.Error("Failed to set response cache: %v", err)
return
}
return
}
func headerValues(headers http.Header) map[string]*HeaderValue {
result := make(map[string]*HeaderValue, len(headers))
for key, values := range headers {
result[key] = &HeaderValue{
Value: values,
}
}
return result
}
func urlEscape(prefix string, u string) string {
key := url.QueryEscape(u)
if len(key) > 200 {
h := sha1.New()
io.WriteString(h, u)
key = string(h.Sum(nil))
}
var buffer bytes.Buffer
buffer.WriteString(prefix)
buffer.WriteString(":")
buffer.WriteString(key)
return buffer.String()
}

View File

@@ -0,0 +1,104 @@
// Code generated by protoc-gen-gogo. DO NOT EDIT.
// source: page.proto
/*
Package cache is a generated protocol buffer package.
It is generated from these files:
page.proto
It has these top-level messages:
ResponseCache
HeaderValue
*/
package cache
import proto "github.com/gogo/protobuf/proto"
import fmt "fmt"
import math "math"
import _ "github.com/gogo/protobuf/gogoproto"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.GoGoProtoPackageIsVersion2 // please upgrade the proto package
type ResponseCache struct {
Status int32 `protobuf:"varint,1,opt,name=Status,proto3" json:"Status,omitempty"`
Header map[string]*HeaderValue `protobuf:"bytes,2,rep,name=Header" json:"Header,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value"`
Data []byte `protobuf:"bytes,3,opt,name=Data,proto3" json:"Data,omitempty"`
}
func (m *ResponseCache) Reset() { *m = ResponseCache{} }
func (m *ResponseCache) String() string { return proto.CompactTextString(m) }
func (*ResponseCache) ProtoMessage() {}
func (*ResponseCache) Descriptor() ([]byte, []int) { return fileDescriptorPage, []int{0} }
func (m *ResponseCache) GetStatus() int32 {
if m != nil {
return m.Status
}
return 0
}
func (m *ResponseCache) GetHeader() map[string]*HeaderValue {
if m != nil {
return m.Header
}
return nil
}
func (m *ResponseCache) GetData() []byte {
if m != nil {
return m.Data
}
return nil
}
type HeaderValue struct {
Value []string `protobuf:"bytes,1,rep,name=Value" json:"Value,omitempty"`
}
func (m *HeaderValue) Reset() { *m = HeaderValue{} }
func (m *HeaderValue) String() string { return proto.CompactTextString(m) }
func (*HeaderValue) ProtoMessage() {}
func (*HeaderValue) Descriptor() ([]byte, []int) { return fileDescriptorPage, []int{1} }
func (m *HeaderValue) GetValue() []string {
if m != nil {
return m.Value
}
return nil
}
func init() {
proto.RegisterType((*ResponseCache)(nil), "cache.responseCache")
proto.RegisterType((*HeaderValue)(nil), "cache.headerValue")
}
func init() { proto.RegisterFile("page.proto", fileDescriptorPage) }
var fileDescriptorPage = []byte{
// 231 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x54, 0x8f, 0x41, 0x4b, 0xc4, 0x30,
0x10, 0x85, 0x49, 0x6b, 0x0b, 0x3b, 0x55, 0x90, 0x41, 0x24, 0xec, 0x29, 0xac, 0x97, 0x5c, 0xcc,
0xc2, 0x7a, 0x59, 0xbc, 0xaa, 0xe0, 0xc5, 0x4b, 0x04, 0xef, 0x69, 0x1d, 0x5b, 0x51, 0x37, 0xa5,
0x4d, 0x84, 0xfd, 0x7f, 0xfe, 0x30, 0xe9, 0xa4, 0x87, 0xee, 0xed, 0x3d, 0xde, 0xf7, 0xe6, 0x31,
0x00, 0xbd, 0x6b, 0xc9, 0xf4, 0x83, 0x0f, 0x1e, 0x8b, 0xc6, 0x35, 0x1d, 0xad, 0x6f, 0xdb, 0xcf,
0xd0, 0xc5, 0xda, 0x34, 0xfe, 0x67, 0xdb, 0xfa, 0xd6, 0x6f, 0x39, 0xad, 0xe3, 0x07, 0x3b, 0x36,
0xac, 0x52, 0x6b, 0xf3, 0x27, 0xe0, 0x62, 0xa0, 0xb1, 0xf7, 0x87, 0x91, 0x1e, 0xa6, 0x03, 0x78,
0x0d, 0xe5, 0x6b, 0x70, 0x21, 0x8e, 0x52, 0x28, 0xa1, 0x0b, 0x3b, 0x3b, 0xdc, 0x43, 0xf9, 0x4c,
0xee, 0x9d, 0x06, 0x99, 0xa9, 0x5c, 0x57, 0x3b, 0x65, 0x78, 0xd0, 0x9c, 0xb4, 0x4d, 0x42, 0x9e,
0x0e, 0x61, 0x38, 0xda, 0x99, 0x47, 0x84, 0xb3, 0x47, 0x17, 0x9c, 0xcc, 0x95, 0xd0, 0xe7, 0x96,
0xf5, 0xfa, 0x05, 0xaa, 0x05, 0x8a, 0x97, 0x90, 0x7f, 0xd1, 0x91, 0x17, 0x57, 0x76, 0x92, 0xa8,
0xa1, 0xf8, 0x75, 0xdf, 0x91, 0x64, 0xa6, 0x84, 0xae, 0x76, 0x38, 0xaf, 0x75, 0x5c, 0x7a, 0x9b,
0x12, 0x9b, 0x80, 0xfb, 0x6c, 0x2f, 0x36, 0x37, 0x50, 0x2d, 0x12, 0xbc, 0x82, 0x82, 0x85, 0x14,
0x2a, 0xd7, 0x2b, 0x9b, 0x4c, 0x5d, 0xf2, 0xcb, 0x77, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff, 0x6f,
0x05, 0xcf, 0xb0, 0x36, 0x01, 0x00, 0x00,
}

View File

@@ -0,0 +1,13 @@
syntax = "proto3";
package cache;
import "github.com/gogo/protobuf/gogoproto/gogo.proto";
message responseCache {
int32 Status = 1;
map<string, headerValue> Header = 2;
bytes Data = 3;
}
message headerValue {
repeated string Value = 1;
}

View File

@@ -0,0 +1,37 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
)
go_library(
name = "go_default_library",
srcs = [
"file.go",
"memcache.go",
"store.go",
],
importpath = "go-common/library/net/http/blademaster/middleware/cache/store",
tags = ["automanaged"],
visibility = ["//visibility:public"],
deps = [
"//library/cache/memcache:go_default_library",
"//library/log:go_default_library",
"//vendor/github.com/pkg/errors:go_default_library",
],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)

View File

@@ -0,0 +1,65 @@
package store
import (
"context"
"io/ioutil"
"os"
"path"
"go-common/library/log"
"github.com/pkg/errors"
)
// FileConfig config of File.
type FileConfig struct {
RootDir string
}
// File is a degrade file service.
type File struct {
c *FileConfig
}
var _ Store = &File{}
// NewFile new a file degrade service.
func NewFile(fc *FileConfig) *File {
if fc == nil {
panic(errors.New("file config is nil"))
}
fs := &File{c: fc}
if err := os.MkdirAll(fs.c.RootDir, 0755); err != nil {
panic(errors.Wrapf(err, "dir: %s", fs.c.RootDir))
}
return fs
}
// Set save the result of location to file.
// expire is not implemented in file storage.
func (fs *File) Set(ctx context.Context, key string, bs []byte, _ int32) (err error) {
file := path.Join(fs.c.RootDir, key)
tmp := file + ".tmp"
if err = ioutil.WriteFile(tmp, bs, 0644); err != nil {
log.Error("ioutil.WriteFile(%s, bs, 0644): error(%v)", tmp, err)
err = errors.Wrapf(err, "key: %s", key)
return
}
if err = os.Rename(tmp, file); err != nil {
log.Error("os.Rename(%s, %s): error(%v)", tmp, file, err)
err = errors.Wrapf(err, "key: %s", key)
return
}
return
}
// Get get result from file by locaiton+params.
func (fs *File) Get(ctx context.Context, key string) (bs []byte, err error) {
p := path.Join(fs.c.RootDir, key)
if bs, err = ioutil.ReadFile(p); err != nil {
log.Error("ioutil.ReadFile(%s): error(%v)", p, err)
err = errors.Wrapf(err, "key: %s", key)
return
}
return
}

View File

@@ -0,0 +1,54 @@
package store
import (
"context"
"go-common/library/cache/memcache"
"go-common/library/log"
)
// Memcache represents the cache with memcached persistence
type Memcache struct {
pool *memcache.Pool
}
// NewMemcache new a memcache store.
func NewMemcache(c *memcache.Config) *Memcache {
if c == nil {
panic("cache config is nil")
}
return &Memcache{
pool: memcache.NewPool(c),
}
}
// Set save the result to memcache store.
func (ms *Memcache) Set(ctx context.Context, key string, value []byte, expire int32) (err error) {
item := &memcache.Item{
Key: key,
Value: value,
Expiration: expire,
}
conn := ms.pool.Get(ctx)
defer conn.Close()
if err = conn.Set(item); err != nil {
log.Error("conn.Set(%s) error(%v)", key, err)
}
return
}
// Get get result from mc by locaiton+params.
func (ms *Memcache) Get(ctx context.Context, key string) ([]byte, error) {
conn := ms.pool.Get(ctx)
defer conn.Close()
r, err := conn.Get(key)
if err != nil {
if err == memcache.ErrNotFound {
//ignore not found error
return nil, nil
}
log.Error("conn.Get(%s) error(%v)", key, err)
return nil, err
}
return r.Value, nil
}

View File

@@ -0,0 +1,15 @@
package store
import (
"context"
)
// Store is the interface of a cache backend
type Store interface {
// Get retrieves an item from the cache. Returns the item or nil, and a bool indicating
// whether the key was found.
Get(ctx context.Context, key string) ([]byte, error)
// Set sets an item to the cache, replacing any existing item.
Set(ctx context.Context, key string, value []byte, expire int32) error
}

View File

@@ -0,0 +1,55 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_test",
"go_library",
)
go_test(
name = "go_default_test",
srcs = ["aqm_test.go"],
embed = [":go_default_library"],
rundir = ".",
tags = ["automanaged"],
deps = [
"//library/log:go_default_library",
"//library/net/http/blademaster:go_default_library",
],
)
go_library(
name = "go_default_library",
srcs = ["aqm.go"],
importpath = "go-common/library/net/http/blademaster/middleware/limit/aqm",
tags = ["automanaged"],
visibility = ["//visibility:public"],
deps = [
"//library/container/queue/aqm:go_default_library",
"//library/ecode:go_default_library",
"//library/net/http/blademaster:go_default_library",
"//library/rate:go_default_library",
"//library/rate/limit:go_default_library",
"//library/stat/prom:go_default_library",
],
)
go_test(
name = "go_default_xtest",
srcs = ["example_test.go"],
tags = ["automanaged"],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)

View File

@@ -0,0 +1,5 @@
### business/blademaster/supervisor
##### Version 1.0.0
1. 完成基本功能与测试

View File

@@ -0,0 +1,5 @@
# Author
lintnaghui
# Reviewer
maojian

View File

@@ -0,0 +1,7 @@
# See the OWNERS docs at https://go.k8s.io/owners
approvers:
- lintnaghui
reviewers:
- lintnaghui
- maojian

View File

@@ -0,0 +1,13 @@
#### business/blademaster/supervisor
##### 项目简介
blademaster 的 aqm middleware主动队列管理请求延迟检测于优先级策略管理
##### 编译环境
- **请只用 Golang v1.8.x 以上版本编译执行**
##### 依赖包
- No other dependency

View File

@@ -0,0 +1,54 @@
package aqm
import (
"context"
"go-common/library/container/queue/aqm"
"go-common/library/ecode"
bm "go-common/library/net/http/blademaster"
"go-common/library/rate"
"go-common/library/rate/limit"
"go-common/library/stat/prom"
)
const (
_family = "blademaster"
)
var (
stats = prom.New().WithState("go_active_queue_mng", []string{"family", "title"})
)
// AQM aqm midleware.
type AQM struct {
limiter rate.Limiter
}
// New return a ratelimit midleware.
func New(conf *aqm.Config) (s *AQM) {
return &AQM{
limiter: limit.New(conf),
}
}
// Limit return a bm handler func.
func (a *AQM) Limit() bm.HandlerFunc {
return func(c *bm.Context) {
done, err := a.limiter.Allow(c)
if err != nil {
stats.Incr(_family, c.Request.URL.Path[1:])
// TODO: priority request.
// c.JSON(nil, err)
// c.Abort()
return
}
defer func() {
if c.Error != nil && !ecode.Deadline.Equal(c.Error) && c.Err() != context.DeadlineExceeded {
done(rate.Ignore)
return
}
done(rate.Success)
}()
c.Next()
}
}

View File

@@ -0,0 +1,47 @@
package aqm
import (
"fmt"
"math/rand"
"net/http"
"sync"
"sync/atomic"
"testing"
"time"
"go-common/library/log"
bm "go-common/library/net/http/blademaster"
)
func init() {
log.Init(nil)
}
func TestAQM(t *testing.T) {
var group sync.WaitGroup
rand.Seed(time.Now().Unix())
eng := bm.Default()
router := eng.Use(New(nil).Limit())
router.GET("/aqm", testaqm)
go eng.Run(":9999")
var errcount int64
for i := 0; i < 100; i++ {
group.Add(1)
go func() {
defer group.Done()
for j := 0; j < 300; j++ {
_, err := http.Get("http://127.0.0.1:9999/aqm")
if err != nil {
atomic.AddInt64(&errcount, 1)
}
}
}()
}
group.Wait()
fmt.Println("errcount", errcount)
}
func testaqm(ctx *bm.Context) {
count := rand.Intn(100)
time.Sleep(time.Millisecond * time.Duration(count))
}

View File

@@ -0,0 +1,9 @@
package aqm_test
// This example create a supervisor middleware instance and attach to a blademaster engine,
// it will allow '/ping' API can be requested with specified policy.
// This example will block all http method except `GET` on '/ping' API in next hour,
// and allow in further.
func Example() {
}

View File

@@ -0,0 +1,74 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_test",
"go_library",
)
go_test(
name = "go_default_test",
srcs = ["permit_test.go"],
embed = [":go_default_library"],
rundir = ".",
tags = ["automanaged"],
deps = [
"//library/cache/memcache:go_default_library",
"//library/container/pool:go_default_library",
"//library/ecode:go_default_library",
"//library/log:go_default_library",
"//library/net/http/blademaster:go_default_library",
"//library/net/netutil/breaker:go_default_library",
"//library/time:go_default_library",
],
)
go_library(
name = "go_default_library",
srcs = [
"permit.go",
"session.go",
],
importpath = "go-common/library/net/http/blademaster/middleware/permit",
tags = ["automanaged"],
visibility = ["//visibility:public"],
deps = [
"//app/admin/main/manager/api:go_default_library",
"//library/cache/memcache:go_default_library",
"//library/ecode:go_default_library",
"//library/log:go_default_library",
"//library/net/http/blademaster:go_default_library",
"//library/net/metadata:go_default_library",
"//library/net/rpc/warden:go_default_library",
"//vendor/github.com/pkg/errors:go_default_library",
],
)
go_test(
name = "go_default_xtest",
srcs = ["example_test.go"],
tags = ["automanaged"],
deps = [
"//library/cache/memcache:go_default_library",
"//library/container/pool:go_default_library",
"//library/net/http/blademaster:go_default_library",
"//library/net/http/blademaster/middleware/permit:go_default_library",
"//library/net/metadata:go_default_library",
"//library/net/netutil/breaker:go_default_library",
"//library/time:go_default_library",
],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)

View File

@@ -0,0 +1,24 @@
### business/blademaster/permit
### Version 1.0.6
1. fix manager用户不存在的情况
### Version 1.0.5
1. permit无配置化
### Version 1.0.4
1. auth添加默认配置
### Version 1.0.3
1. auth去掉写cookie操作
##### Version 1.0.2
1. 修复dashboard auth cookie校验username问题
##### Version 1.0.1
1. auth后将username写入context
##### Version 1.0.0
1. 使用 dashbaord 来认证用户使用manager来获取用户权限使用 memcache 来缓存用户Session
2. 完成基本功能与测试

View File

@@ -0,0 +1,13 @@
#### business/blademaster/identify
##### 项目简介
blademaster 的 auth middleware主要用于设置路由的后台管理系统的登陆验证和鉴权策略
##### 编译环境
- **请只用 Golang v1.8.x 以上版本编译执行**
##### 依赖包
- No other dependency

View File

@@ -0,0 +1,101 @@
package permit_test
import (
"fmt"
"time"
"go-common/library/cache/memcache"
"go-common/library/container/pool"
bm "go-common/library/net/http/blademaster"
"go-common/library/net/http/blademaster/middleware/permit"
"go-common/library/net/metadata"
"go-common/library/net/netutil/breaker"
xtime "go-common/library/time"
)
// This example create a permit middleware instance and attach to several path,
// it will validate request by specified policy and put extra information into context. e.g., `uid`.
// It provides additional handler functions to provide the identification for your business handler.
func Example() {
a := permit.New(&permit.Config{
DsHTTPClient: &bm.ClientConfig{
App: &bm.App{
Key: "manager-go",
Secret: "949bbb2dd3178252638c2407578bc7ad",
},
Dial: xtime.Duration(time.Second),
Timeout: xtime.Duration(time.Second),
KeepAlive: xtime.Duration(time.Second * 10),
Breaker: &breaker.Config{
Window: xtime.Duration(time.Second),
Sleep: xtime.Duration(time.Millisecond * 100),
Bucket: 10,
Ratio: 0.5,
Request: 100,
},
},
MaHTTPClient: &bm.ClientConfig{
App: &bm.App{
Key: "f6433799dbd88751",
Secret: "36f8ddb1806207fe07013ab6a77a3935",
},
Dial: xtime.Duration(time.Second),
Timeout: xtime.Duration(time.Second),
KeepAlive: xtime.Duration(time.Second * 10),
Breaker: &breaker.Config{
Window: xtime.Duration(time.Second),
Sleep: xtime.Duration(time.Millisecond * 100),
Bucket: 10,
Ratio: 0.5,
Request: 100,
},
},
Session: &permit.SessionConfig{
SessionIDLength: 32,
CookieLifeTime: 1800,
CookieName: "mng-go",
Domain: ".bilibili.co",
Memcache: &memcache.Config{
Config: &pool.Config{
Active: 10,
Idle: 5,
IdleTimeout: xtime.Duration(time.Second * 80),
},
Name: "go-business/permit",
Proto: "tcp",
Addr: "172.16.33.54:11211",
DialTimeout: xtime.Duration(time.Millisecond * 1000),
ReadTimeout: xtime.Duration(time.Millisecond * 1000),
WriteTimeout: xtime.Duration(time.Millisecond * 1000),
},
},
ManagerHost: "http://uat-manager.bilibili.co",
DashboardHost: "http://uat-dashboard-mng.bilibili.co",
DashboardCaller: "manager-go",
})
p := permit.New2(nil)
e := bm.NewServer(nil)
//Check whether the user has logged in
e.GET("/login", a.Verify(), func(c *bm.Context) {
c.JSON("pass", nil)
})
//Check whether the user has logged in,and check th user has the access permisson to the specifed path
e.GET("/tag/del", a.Permit("TAG_DEL"), func(c *bm.Context) {
uid := metadata.Int64(c, metadata.Uid)
username := metadata.String(c, metadata.Username)
c.JSON(fmt.Sprintf("pass uid(%d) username(%s)", uid, username), nil)
})
e.GET("/check/login", p.Verify2(), func(c *bm.Context) {
c.JSON("pass", nil)
})
e.POST("/tag/del", p.Permit2("TAG_DEL"), func(c *bm.Context) {
uid := metadata.Int64(c, metadata.Uid)
username := metadata.String(c, metadata.Username)
c.JSON(fmt.Sprintf("pass uid(%d) username(%s)", uid, username), nil)
})
e.Run(":18080")
}

View File

@@ -0,0 +1,321 @@
package permit
import (
"net/url"
mng "go-common/app/admin/main/manager/api"
"go-common/library/ecode"
"go-common/library/log"
bm "go-common/library/net/http/blademaster"
"go-common/library/net/metadata"
"go-common/library/net/rpc/warden"
"github.com/pkg/errors"
)
const (
_verifyURI = "/api/session/verify"
_permissionURI = "/x/admin/manager/permission"
_sessIDKey = "_AJSESSIONID"
_sessUIDKey = "uid" // manager user_id
_sessUnKey = "username" // LDAP username
_defaultDomain = ".bilibili.co"
_defaultCookieName = "mng-go"
_defaultCookieLifeTime = 2592000
// CtxPermissions will be set into ctx.
CtxPermissions = "permissions"
)
// permissions .
type permissions struct {
UID int64 `json:"uid"`
Perms []string `json:"perms"`
}
// Permit is an auth middleware.
type Permit struct {
verifyURI string
permissionURI string
dashboardCaller string
dsClient *bm.Client // dashboard client
maClient *bm.Client // manager-admin client
sm *SessionManager // user Session
mng.PermitClient // mng grpc client
}
//Verify only export Verify function because of less configure
type Verify interface {
Verify() bm.HandlerFunc
}
// Config identify config.
type Config struct {
DsHTTPClient *bm.ClientConfig // dashboard client config. appkey can not reuse.
MaHTTPClient *bm.ClientConfig // manager-admin client config
Session *SessionConfig
ManagerHost string
DashboardHost string
DashboardCaller string
}
// Config2 .
type Config2 struct {
MngClient *warden.ClientConfig
Session *SessionConfig
}
// New new an auth service.
func New(c *Config) *Permit {
a := &Permit{
dashboardCaller: c.DashboardCaller,
verifyURI: c.DashboardHost + _verifyURI,
permissionURI: c.ManagerHost + _permissionURI,
dsClient: bm.NewClient(c.DsHTTPClient),
maClient: bm.NewClient(c.MaHTTPClient),
sm: newSessionManager(c.Session),
}
return a
}
// New2 .
func New2(c *warden.ClientConfig) *Permit {
permitClient, err := mng.NewClient(c)
if err != nil {
panic(errors.WithMessage(err, "Failed to dial mng rpc server"))
}
return &Permit{
PermitClient: permitClient,
sm: &SessionManager{},
}
}
// NewVerify new a verify service.
func NewVerify(c *Config) Verify {
a := &Permit{
verifyURI: c.DashboardHost + _verifyURI,
dsClient: bm.NewClient(c.DsHTTPClient),
dashboardCaller: c.DashboardCaller,
sm: newSessionManager(c.Session),
}
return a
}
// Verify2 check whether the user has logged in.
func (p *Permit) Verify2() bm.HandlerFunc {
return func(ctx *bm.Context) {
sid, username, err := p.login2(ctx)
if err != nil {
ctx.JSON(nil, ecode.Unauthorized)
ctx.Abort()
return
}
ctx.Set(_sessUnKey, username)
p.sm.setHTTPCookie(ctx, _defaultCookieName, sid)
}
}
// Verify return bm HandlerFunc which check whether the user has logged in.
func (p *Permit) Verify() bm.HandlerFunc {
return func(ctx *bm.Context) {
si, err := p.login(ctx)
if err != nil {
ctx.JSON(nil, ecode.Unauthorized)
ctx.Abort()
return
}
p.sm.SessionRelease(ctx, si)
}
}
// Permit return bm HandlerFunc which check whether the user has logged in and has the access permission of the location.
// If `permit` is empty,it will allow any logged in request.
func (p *Permit) Permit(permit string) bm.HandlerFunc {
return func(ctx *bm.Context) {
var (
si *Session
uid int64
perms []string
err error
)
si, err = p.login(ctx)
if err != nil {
ctx.JSON(nil, ecode.Unauthorized)
ctx.Abort()
return
}
defer p.sm.SessionRelease(ctx, si)
uid, perms, err = p.permissions(ctx, si.Get(_sessUnKey).(string))
if err == nil {
si.Set(_sessUIDKey, uid)
ctx.Set(_sessUIDKey, uid)
if md, ok := metadata.FromContext(ctx); ok {
md[metadata.Uid] = uid
}
}
if len(perms) > 0 {
ctx.Set(CtxPermissions, perms)
}
if !p.permit(permit, perms) {
ctx.JSON(nil, ecode.AccessDenied)
ctx.Abort()
return
}
}
}
// login check whether the user has logged in.
func (p *Permit) login(ctx *bm.Context) (si *Session, err error) {
si = p.sm.SessionStart(ctx)
if si.Get(_sessUnKey) == nil {
var username string
if username, err = p.verify(ctx); err != nil {
return
}
si.Set(_sessUnKey, username)
}
ctx.Set(_sessUnKey, si.Get(_sessUnKey))
if md, ok := metadata.FromContext(ctx); ok {
md[metadata.Username] = si.Get(_sessUnKey)
}
return
}
// Permit2 same function as permit function but reply on grpc.
func (p *Permit) Permit2(permit string) bm.HandlerFunc {
return func(ctx *bm.Context) {
sid, username, err := p.login2(ctx)
if err != nil {
ctx.JSON(nil, ecode.Unauthorized)
ctx.Abort()
return
}
p.sm.setHTTPCookie(ctx, _defaultCookieName, sid)
ctx.Set(_sessUnKey, username)
if md, ok := metadata.FromContext(ctx); ok {
md[metadata.Username] = username
}
reply, err := p.Permissions(ctx, &mng.PermissionReq{Username: username})
if err != nil {
if ecode.NothingFound.Equal(err) && permit != "" {
ctx.JSON(nil, ecode.AccessDenied)
ctx.Abort()
}
return
}
ctx.Set(_sessUIDKey, reply.Uid)
if md, ok := metadata.FromContext(ctx); ok {
md[metadata.Uid] = reply.Uid
}
if len(reply.Perms) > 0 {
ctx.Set(CtxPermissions, reply.Perms)
}
if !p.permit(permit, reply.Perms) {
ctx.JSON(nil, ecode.AccessDenied)
ctx.Abort()
return
}
}
}
// login2 .
func (p *Permit) login2(ctx *bm.Context) (sid, uname string, err error) {
var dsbsid, mngsid string
dsbck, err := ctx.Request.Cookie(_sessIDKey)
if err == nil {
dsbsid = dsbck.Value
}
if dsbsid == "" {
err = ecode.Unauthorized
return
}
mngck, err := ctx.Request.Cookie(_defaultCookieName)
if err == nil {
mngsid = mngck.Value
}
reply, err := p.Login(ctx, &mng.LoginReq{Mngsid: mngsid, Dsbsid: dsbsid})
if err != nil {
log.Error("mng rpc Login error(%v)", err)
return
}
sid = reply.Sid
uname = reply.Username
return
}
func (p *Permit) verify(ctx *bm.Context) (username string, err error) {
var (
sid string
r = ctx.Request
)
session, err := r.Cookie(_sessIDKey)
if err == nil {
sid = session.Value
}
if sid == "" {
err = ecode.Unauthorized
return
}
username, err = p.verifyDashboard(ctx, sid)
return
}
// permit check whether user has the access permission of the location.
func (p *Permit) permit(permit string, permissions []string) bool {
if permit == "" {
return true
}
for _, p := range permissions {
if p == permit {
// access the permit
return true
}
}
return false
}
// verifyDashboard check whether the user is valid from Dashboard.
func (p *Permit) verifyDashboard(ctx *bm.Context, sid string) (username string, err error) {
params := url.Values{}
params.Set("session_id", sid)
params.Set("encrypt", "md5")
params.Set("caller", p.dashboardCaller)
var res struct {
Code int `json:"code"`
UserName string `json:"username"`
}
if err = p.dsClient.Get(ctx, p.verifyURI, metadata.String(ctx, metadata.RemoteIP), params, &res); err != nil {
log.Error("dashboard get verify Session url(%s) error(%v)", p.verifyURI+"?"+params.Encode(), err)
return
}
if ecode.Int(res.Code) != ecode.OK {
log.Error("dashboard get verify Session url(%s) error(%v)", p.verifyURI+"?"+params.Encode(), res.Code)
err = ecode.Int(res.Code)
return
}
username = res.UserName
return
}
// permissions get user's permisssions from manager-admin.
func (p *Permit) permissions(ctx *bm.Context, username string) (uid int64, perms []string, err error) {
params := url.Values{}
params.Set(_sessUnKey, username)
var res struct {
Code int `json:"code"`
Data permissions `json:"data"`
}
if err = p.maClient.Get(ctx, p.permissionURI, metadata.String(ctx, metadata.RemoteIP), params, &res); err != nil {
log.Error("dashboard get permissions url(%s) error(%v)", p.permissionURI+"?"+params.Encode(), err)
return
}
if ecode.Int(res.Code) != ecode.OK {
log.Error("dashboard get permissions url(%s) error(%v)", p.permissionURI+"?"+params.Encode(), res.Code)
err = ecode.Int(res.Code)
return
}
perms = res.Data.Perms
uid = res.Data.UID
return
}

View File

@@ -0,0 +1,294 @@
package permit
import (
"context"
"net/http"
"net/url"
"sync"
"testing"
"time"
"go-common/library/cache/memcache"
"go-common/library/container/pool"
"go-common/library/ecode"
"go-common/library/log"
bm "go-common/library/net/http/blademaster"
"go-common/library/net/netutil/breaker"
xtime "go-common/library/time"
)
var (
once sync.Once
)
type Response struct {
Code int `json:"code"`
Data string `json:"data"`
}
func init() {
log.Init(nil)
}
func client() *bm.Client {
return bm.NewClient(&bm.ClientConfig{
App: &bm.App{
Key: "test",
Secret: "test",
},
Dial: xtime.Duration(time.Second),
Timeout: xtime.Duration(time.Second),
KeepAlive: xtime.Duration(time.Second * 10),
Breaker: &breaker.Config{
Window: xtime.Duration(time.Second),
Sleep: xtime.Duration(time.Millisecond * 100),
Bucket: 10,
Ratio: 0.5,
Request: 100,
},
})
}
func getPermit() *Permit {
return New(&Config{
DsHTTPClient: &bm.ClientConfig{
App: &bm.App{
Key: "manager-go",
Secret: "949bbb2dd3178252638c2407578bc7ad",
},
Dial: xtime.Duration(time.Second),
Timeout: xtime.Duration(time.Second),
KeepAlive: xtime.Duration(time.Second * 10),
Breaker: &breaker.Config{
Window: xtime.Duration(time.Second),
Sleep: xtime.Duration(time.Millisecond * 100),
Bucket: 10,
Ratio: 0.5,
Request: 100,
},
},
MaHTTPClient: &bm.ClientConfig{
App: &bm.App{
Key: "f6433799dbd88751",
Secret: "36f8ddb1806207fe07013ab6a77a3935",
},
Dial: xtime.Duration(time.Second),
Timeout: xtime.Duration(time.Second),
KeepAlive: xtime.Duration(time.Second * 10),
Breaker: &breaker.Config{
Window: xtime.Duration(time.Second),
Sleep: xtime.Duration(time.Millisecond * 100),
Bucket: 10,
Ratio: 0.5,
Request: 100,
},
},
Session: &SessionConfig{
SessionIDLength: 32,
CookieLifeTime: 1800,
CookieName: "mng-go",
Domain: ".bilibili.co",
Memcache: &memcache.Config{
Config: &pool.Config{
Active: 10,
Idle: 5,
IdleTimeout: xtime.Duration(time.Second * 80),
},
Name: "go-business/auth",
Proto: "tcp",
Addr: "172.16.33.54:11211",
DialTimeout: xtime.Duration(time.Millisecond * 1000),
ReadTimeout: xtime.Duration(time.Millisecond * 1000),
WriteTimeout: xtime.Duration(time.Millisecond * 1000),
},
},
ManagerHost: "http://uat-manager.bilibili.co",
DashboardHost: "http://dashboard-mng.bilibili.co",
DashboardCaller: "manager-go",
})
}
func engine() *bm.Engine {
e := bm.NewServer(nil)
a := getPermit()
e.GET("/login", a.Verify(), func(c *bm.Context) {
c.JSON("pass", nil)
})
e.GET("/tag/del", a.Permit("TAG_DEL"), func(c *bm.Context) {
c.JSON("pass", nil)
})
e.GET("/tag/admin", a.Permit("TAG_ADMIN"), func(c *bm.Context) {
c.JSON("pass", nil)
})
return e
}
func setSession(uid int64, username string) (string, error) {
a := getPermit()
sv := a.sm.newSession(context.TODO())
sv.Set("username", username)
mcConn := a.sm.mc.Get(context.TODO())
defer mcConn.Close()
key := sv.Sid
item := &memcache.Item{
Key: key,
Object: sv,
Flags: memcache.FlagJSON,
Expiration: int32(a.sm.c.CookieLifeTime),
}
if err := mcConn.Set(item); err != nil {
return "", err
}
return key, nil
}
func startEngine(t *testing.T) func() {
return func() {
e := engine()
err := e.Run(":18080")
if err != nil {
t.Fatalf("failed to run server!%v", err)
}
}
}
func TestLoginSuccess(t *testing.T) {
go once.Do(startEngine(t))
time.Sleep(time.Millisecond * 100)
sid, err := setSession(2233, "caoguoliang")
if err != nil {
t.Fatalf("faild to set session !err:=%v", err)
}
query := url.Values{}
query.Set("test", "test")
cli := client()
req, err := cli.NewRequest("GET", "http://127.0.0.1:18080/login", "", query)
if err != nil {
t.Fatalf("Failed to build request: %v", err)
}
req.AddCookie(&http.Cookie{
Name: "mng-go",
Value: sid,
})
req.AddCookie(&http.Cookie{
Name: "username",
Value: "caoguoliang",
})
req.AddCookie(&http.Cookie{
Name: "_AJSESSIONID",
Value: "87fa8450e93511e79ed8522233007f8a",
})
res := Response{}
if err := cli.Do(context.TODO(), req, &res); err != nil {
t.Fatalf("Failed to send request: %v", err)
}
if res.Code != 0 || res.Data != "pass" {
t.Fatalf("Unexpected response code(%d) data(%v)", res.Code, res.Data)
}
}
func TestLoginFail(t *testing.T) {
go once.Do(startEngine(t))
time.Sleep(time.Millisecond * 100)
query := url.Values{}
query.Set("test", "test")
cli := client()
req, err := cli.NewRequest("GET", "http://127.0.0.1:18080/login", "", query)
if err != nil {
t.Fatalf("Failed to build request: %v", err)
}
req.AddCookie(&http.Cookie{
Name: "mng-go",
Value: "fakesess",
})
req.AddCookie(&http.Cookie{
Name: "username",
Value: "caoguoliang",
})
req.AddCookie(&http.Cookie{
Name: "_AJSESSIONID",
Value: "testsess",
})
res := Response{}
if err := cli.Do(context.TODO(), req, &res); err != nil {
t.Fatalf("Failed to send request: %v", err)
}
if res.Code != ecode.Unauthorized.Code() {
t.Fatalf("This request should be forbidden: code(%d) data(%v)", res.Code, res.Data)
}
}
func TestVerifySuccess(t *testing.T) {
go once.Do(startEngine(t))
time.Sleep(time.Millisecond * 100)
sid, err := setSession(2233, "caoguoliang")
if err != nil {
t.Fatalf("faild to set session !err:=%v", err)
}
query := url.Values{}
query.Set("test", "test")
cli := client()
req, err := cli.NewRequest("GET", "http://127.0.0.1:18080/tag/del", "", query)
if err != nil {
t.Fatalf("Failed to build request: %v", err)
}
req.AddCookie(&http.Cookie{
Name: "mng-go",
Value: sid,
})
req.AddCookie(&http.Cookie{
Name: "username",
Value: "caoguoliang",
})
req.AddCookie(&http.Cookie{
Name: "_AJSESSIONID",
Value: "87fa8450e93511e79ed8522233007f8a",
})
res := Response{}
if err := cli.Do(context.TODO(), req, &res); err != nil {
t.Fatalf("Failed to send request: %v", err)
}
if res.Code != 0 || res.Data != "pass" {
t.Fatalf("Unexpected response code(%d) data(%v)", res.Code, res.Data)
}
}
func TestVerifyFail(t *testing.T) {
go once.Do(startEngine(t))
time.Sleep(time.Millisecond * 100)
sid, err := setSession(2233, "caoguoliang")
if err != nil {
t.Fatalf("faild to set session !err:=%v", err)
}
query := url.Values{}
query.Set("test", "test")
cli := client()
req, err := cli.NewRequest("GET", "http://127.0.0.1:18080/tag/admin", "", query)
if err != nil {
t.Fatalf("Failed to build request: %v", err)
}
req.AddCookie(&http.Cookie{
Name: "mng-go",
Value: sid,
})
req.AddCookie(&http.Cookie{
Name: "username",
Value: "caoguoliang",
})
req.AddCookie(&http.Cookie{
Name: "_AJSESSIONID",
Value: "87fa8450e93511e79ed8522233007f8a",
})
res := Response{}
if err := cli.Do(context.TODO(), req, &res); err != nil {
t.Fatalf("Failed to send request: %v", err)
}
if res.Code != ecode.AccessDenied.Code() {
t.Fatalf("This request should be forbidden: code(%d) data(%v)", res.Code, res.Data)
}
}

View File

@@ -0,0 +1,152 @@
package permit
import (
"context"
"crypto/rand"
"encoding/hex"
"net/http"
"net/url"
"sync"
"time"
"go-common/library/cache/memcache"
"go-common/library/log"
bm "go-common/library/net/http/blademaster"
)
// Session http session.
type Session struct {
Sid string
lock sync.RWMutex
Values map[string]interface{}
}
// SessionConfig config of Session.
type SessionConfig struct {
SessionIDLength int
CookieLifeTime int
CookieName string
Domain string
Memcache *memcache.Config
}
// SessionManager .
type SessionManager struct {
mc *memcache.Pool // Session cache
c *SessionConfig
}
// newSessionManager .
func newSessionManager(c *SessionConfig) (s *SessionManager) {
s = &SessionManager{
mc: memcache.NewPool(c.Memcache),
c: c,
}
return
}
// SessionStart start session.
func (s *SessionManager) SessionStart(ctx *bm.Context) (si *Session) {
// check manager Session id, if err or no exist need new one.
if si, _ = s.cache(ctx); si == nil {
si = s.newSession(ctx)
}
return
}
// SessionRelease flush session into store.
func (s *SessionManager) SessionRelease(ctx *bm.Context, sv *Session) {
// set http cookie
s.setHTTPCookie(ctx, s.c.CookieName, sv.Sid)
// set mc
conn := s.mc.Get(ctx)
defer conn.Close()
key := sv.Sid
item := &memcache.Item{
Key: key,
Object: sv,
Flags: memcache.FlagJSON,
Expiration: int32(s.c.CookieLifeTime),
}
if err := conn.Set(item); err != nil {
log.Error("SessionManager set error(%s,%v)", key, err)
}
}
// SessionDestroy destroy session.
func (s *SessionManager) SessionDestroy(ctx *bm.Context, sv *Session) {
conn := s.mc.Get(ctx)
defer conn.Close()
if err := conn.Delete(sv.Sid); err != nil {
log.Error("SessionManager delete error(%s,%v)", sv.Sid, err)
}
}
func (s *SessionManager) cache(ctx *bm.Context) (res *Session, err error) {
ck, err := ctx.Request.Cookie(s.c.CookieName)
if err != nil || ck == nil {
return
}
sid := ck.Value
// get from cache
conn := s.mc.Get(ctx)
defer conn.Close()
r, err := conn.Get(sid)
if err != nil {
if err == memcache.ErrNotFound {
err = nil
return
}
log.Error("conn.Get(%s) error(%v)", sid, err)
return
}
res = &Session{}
if err = conn.Scan(r, res); err != nil {
log.Error("conn.Scan(%v) error(%v)", string(r.Value), err)
}
return
}
func (s *SessionManager) newSession(ctx context.Context) (res *Session) {
b := make([]byte, s.c.SessionIDLength)
n, err := rand.Read(b)
if n != len(b) || err != nil {
return nil
}
res = &Session{
Sid: hex.EncodeToString(b),
Values: make(map[string]interface{}),
}
return
}
func (s *SessionManager) setHTTPCookie(ctx *bm.Context, name, value string) {
cookie := &http.Cookie{
Name: name,
Value: url.QueryEscape(value),
Path: "/",
HttpOnly: true,
Domain: _defaultDomain,
}
cookie.MaxAge = _defaultCookieLifeTime
cookie.Expires = time.Now().Add(time.Duration(_defaultCookieLifeTime) * time.Second)
http.SetCookie(ctx.Writer, cookie)
}
// Get get value by key.
func (s *Session) Get(key string) (value interface{}) {
s.lock.RLock()
defer s.lock.RUnlock()
value = s.Values[key]
return
}
// Set set value into session.
func (s *Session) Set(key string, value interface{}) (err error) {
s.lock.Lock()
defer s.lock.Unlock()
s.Values[key] = value
return
}

View File

@@ -0,0 +1,59 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_test",
"go_library",
)
go_test(
name = "go_default_test",
srcs = ["proxy_test.go"],
embed = [":go_default_library"],
rundir = ".",
tags = ["automanaged"],
deps = [
"//library/log:go_default_library",
"//library/net/http/blademaster:go_default_library",
"//vendor/github.com/stretchr/testify/assert:go_default_library",
],
)
go_library(
name = "go_default_library",
srcs = ["proxy.go"],
importpath = "go-common/library/net/http/blademaster/middleware/proxy",
tags = ["automanaged"],
visibility = ["//visibility:public"],
deps = [
"//library/conf/env:go_default_library",
"//library/log:go_default_library",
"//library/net/http/blademaster:go_default_library",
"//library/net/metadata:go_default_library",
"//vendor/github.com/pkg/errors:go_default_library",
],
)
go_test(
name = "go_default_xtest",
srcs = ["example_test.go"],
tags = ["automanaged"],
deps = [
"//library/net/http/blademaster:go_default_library",
"//library/net/http/blademaster/middleware/proxy:go_default_library",
],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)

View File

@@ -0,0 +1,5 @@
### business/blademaster/proxy
##### Version 1.0.0
1. 完成基本功能与测试

View File

@@ -0,0 +1,5 @@
# Author
zhoujiahui
# Reviewer
maojian

View File

@@ -0,0 +1,7 @@
# See the OWNERS docs at https://go.k8s.io/owners
approvers:
- zhoujiahui
reviewers:
- maojian
- zhoujiahui

View File

@@ -0,0 +1,13 @@
#### business/blademaster/proxy
##### 项目简介
blademaster 的 reverse proxy middleware主要用于转发一些 API 请求
##### 编译环境
- **请只用 Golang v1.8.x 以上版本编译执行**
##### 依赖包
- No other dependency

View File

@@ -0,0 +1,49 @@
package proxy_test
import (
"go-common/library/net/http/blademaster"
"go-common/library/net/http/blademaster/middleware/proxy"
)
// This example create several reverse proxy to show how to use proxy middleware.
// We proxy three path to `api.bilibili.com` and return response without any changes.
func Example() {
proxies := map[string]string{
"/index": "http://api.bilibili.com/html/index",
"/ping": "http://api.bilibili.com/api/ping",
"/api/versions": "http://api.bilibili.com/api/web/versions",
}
engine := blademaster.Default()
for path, ep := range proxies {
engine.GET(path, proxy.NewAlways(ep))
}
engine.Run(":18080")
}
// This example create several reverse proxy to show how to use jd proxy middleware.
// The request will be proxied to destination only when request is from specified datacenter.
func ExampleNewZoneProxy() {
proxies := map[string]string{
"/index": "http://api.bilibili.com/html/index",
"/ping": "http://api.bilibili.com/api/ping",
"/api/versions": "http://api.bilibili.com/api/web/versions",
}
engine := blademaster.Default()
// proxy to specified destination
for path, ep := range proxies {
engine.GET(path, proxy.NewZoneProxy("sh004", ep), func(ctx *blademaster.Context) {
ctx.String(200, "Origin")
})
}
// proxy with request path
ug := engine.Group("/update", proxy.NewZoneProxy("sh004", "http://sh001-api.bilibili.com"))
ug.POST("/name", func(ctx *blademaster.Context) {
ctx.String(500, "Should not be accessed")
})
ug.POST("/sign", func(ctx *blademaster.Context) {
ctx.String(500, "Should not be accessed")
})
engine.Run(":18080")
}

View File

@@ -0,0 +1,118 @@
package proxy
import (
"bytes"
"io"
"io/ioutil"
stdlog "log"
"net/http"
"net/http/httputil"
"net/url"
"go-common/library/conf/env"
"go-common/library/log"
bm "go-common/library/net/http/blademaster"
"go-common/library/net/metadata"
"github.com/pkg/errors"
)
type endpoint struct {
url *url.URL
proxy *httputil.ReverseProxy
condition func(ctx *bm.Context) bool
}
type logger struct{}
func (logger) Write(p []byte) (int, error) {
log.Warn("%s", string(p))
return len(p), nil
}
func newep(rawurl string, condition func(ctx *bm.Context) bool) *endpoint {
u, err := url.Parse(rawurl)
if err != nil {
panic(errors.Errorf("Invalid URL: %s", rawurl))
}
e := &endpoint{
url: u,
}
e.proxy = &httputil.ReverseProxy{
Director: e.director,
ErrorLog: stdlog.New(logger{}, "bm.proxy: ", stdlog.LstdFlags),
}
e.condition = condition
return e
}
func (e *endpoint) director(req *http.Request) {
req.URL.Scheme = e.url.Scheme
req.URL.Host = e.url.Host
// keep the origin request path
if e.url.Path != "" {
req.URL.Path = e.url.Path
}
body, length := rebuildBody(req)
req.Body = body
req.ContentLength = int64(length)
}
func (e *endpoint) ServeHTTP(ctx *bm.Context) {
req := ctx.Request
ip := metadata.String(ctx, metadata.RemoteIP)
logArgs := []log.D{
log.KV("method", req.Method),
log.KV("ip", ip),
log.KV("path", req.URL.Path),
log.KV("params", req.Form.Encode()),
}
if !e.condition(ctx) {
logArgs = append(logArgs, log.KV("proxied", "false"))
log.Infov(ctx, logArgs...)
return
}
logArgs = append(logArgs, log.KV("proxied", "true"))
log.Infov(ctx, logArgs...)
e.proxy.ServeHTTP(ctx.Writer, ctx.Request)
ctx.Abort()
}
func rebuildBody(req *http.Request) (io.ReadCloser, int) {
// GET request
if req.Body == nil {
return nil, 0
}
// Submit with form
if len(req.PostForm) > 0 {
br := bytes.NewReader([]byte(req.PostForm.Encode()))
return ioutil.NopCloser(br), br.Len()
}
// copy the original body
bodyBytes, _ := ioutil.ReadAll(req.Body)
br := bytes.NewReader(bodyBytes)
return ioutil.NopCloser(br), br.Len()
}
func always(ctx *bm.Context) bool {
return true
}
// NewZoneProxy is
func NewZoneProxy(matchZone, dst string) bm.HandlerFunc {
ep := newep(dst, func(*bm.Context) bool {
if env.Zone == matchZone {
return true
}
return false
})
return ep.ServeHTTP
}
// NewAlways is
func NewAlways(dst string) bm.HandlerFunc {
ep := newep(dst, always)
return ep.ServeHTTP
}

View File

@@ -0,0 +1,170 @@
package proxy
import (
"bytes"
"context"
"net/http"
"net/url"
"sync"
"testing"
"time"
"go-common/library/log"
bm "go-common/library/net/http/blademaster"
"github.com/stretchr/testify/assert"
)
func init() {
log.Init(nil)
}
func TestProxy(t *testing.T) {
engine := bm.Default()
engine.GET("/icon", NewAlways("http://api.bilibili.com/x/web-interface/index/icon"))
engine.POST("/x/web-interface/archive/like", NewAlways("http://api.bilibili.com"))
go engine.Run(":18080")
defer func() {
engine.Server().Shutdown(context.TODO())
}()
time.Sleep(time.Second)
req, err := http.NewRequest("GET", "http://127.0.0.1:18080/icon", nil)
assert.NoError(t, err)
req.Host = "api.bilibili.com"
resp, err := http.DefaultClient.Do(req)
assert.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, 200, resp.StatusCode)
// proxy form request
form := url.Values{}
form.Set("arg1", "1")
form.Set("arg2", "2")
req, err = http.NewRequest("POST", "http://127.0.0.1:18080/x/web-interface/archive/like?param=test", bytes.NewReader([]byte(form.Encode())))
assert.NoError(t, err)
req.Host = "api.bilibili.com"
resp, err = http.DefaultClient.Do(req)
assert.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, 200, resp.StatusCode)
// proxy json request
bs := []byte(`{"arg1": 1, "arg2": 2}`)
req, err = http.NewRequest("POST", "http://127.0.0.1:18080/x/web-interface/archive/like?param=test", bytes.NewReader(bs))
assert.NoError(t, err)
req.Host = "api.bilibili.com"
req.Header.Set("Content-Type", "application/json; charset=utf-8")
resp, err = http.DefaultClient.Do(req)
assert.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, 200, resp.StatusCode)
}
func TestProxyRace(t *testing.T) {
engine := bm.Default()
engine.GET("/icon", NewAlways("http://api.bilibili.com/x/web-interface/index/icon"))
go engine.Run(":18080")
defer func() {
engine.Server().Shutdown(context.TODO())
}()
time.Sleep(time.Second)
wg := sync.WaitGroup{}
for i := 0; i < 20; i++ {
wg.Add(1)
go func() {
defer wg.Done()
req, err := http.NewRequest("GET", "http://127.0.0.1:18080/icon", nil)
assert.NoError(t, err)
req.Host = "api.bilibili.com"
resp, err := http.DefaultClient.Do(req)
assert.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, 200, resp.StatusCode)
}()
}
wg.Wait()
}
func TestZoneProxy(t *testing.T) {
engine := bm.Default()
engine.GET("/icon", NewZoneProxy("sh004", "http://api.bilibili.com/x/web-interface/index/icon"), func(ctx *bm.Context) {
ctx.AbortWithStatus(500)
})
engine.GET("/icon2", NewZoneProxy("none", "http://api.bilibili.com/x/web-interface/index/icon2"), func(ctx *bm.Context) {
ctx.AbortWithStatus(200)
})
ug := engine.Group("/update", NewZoneProxy("sh004", "http://api.bilibili.com"))
ug.POST("/name", func(ctx *bm.Context) {
ctx.AbortWithStatus(500)
})
ug.POST("/sign", func(ctx *bm.Context) {
ctx.AbortWithStatus(500)
})
go engine.Run(":18080")
defer func() {
engine.Server().Shutdown(context.TODO())
}()
time.Sleep(time.Second)
req, err := http.NewRequest("GET", "http://127.0.0.1:18080/icon", nil)
assert.NoError(t, err)
req.Host = "api.bilibili.com"
req.Header.Set("X-BILI-SLB", "shjd-out-slb")
resp, err := http.DefaultClient.Do(req)
assert.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, 200, resp.StatusCode)
req.URL.Path = "/icon2"
resp, err = http.DefaultClient.Do(req)
assert.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, 200, resp.StatusCode)
req.URL.Path = "/update/name"
resp, err = http.DefaultClient.Do(req)
assert.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, 200, resp.StatusCode)
req.URL.Path = "/update/sign"
resp, err = http.DefaultClient.Do(req)
assert.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, 200, resp.StatusCode)
}
func BenchmarkProxy(b *testing.B) {
engine := bm.Default()
engine.GET("/icon", NewAlways("http://api.bilibili.com/x/web-interface/index/icon"))
go engine.Run(":18080")
defer func() {
engine.Server().Shutdown(context.TODO())
}()
time.Sleep(time.Second)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
req, err := http.NewRequest("GET", "http://127.0.0.1:18080/icon", nil)
assert.NoError(b, err)
req.Host = "api.bilibili.com"
resp, err := http.DefaultClient.Do(req)
assert.NoError(b, err)
defer resp.Body.Close()
assert.Equal(b, 200, resp.StatusCode)
}
})
}

View File

@@ -0,0 +1,53 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_test",
"go_library",
)
go_test(
name = "go_default_test",
srcs = ["limit_test.go"],
embed = [":go_default_library"],
rundir = ".",
tags = ["automanaged"],
deps = ["//library/net/http/blademaster:go_default_library"],
)
go_library(
name = "go_default_library",
srcs = ["limit.go"],
importpath = "go-common/library/net/http/blademaster/middleware/rate",
tags = ["automanaged"],
visibility = ["//visibility:public"],
deps = [
"//library/log:go_default_library",
"//library/net/http/blademaster:go_default_library",
"//vendor/golang.org/x/time/rate:go_default_library",
],
)
go_test(
name = "go_default_xtest",
srcs = ["example_test.go"],
tags = ["automanaged"],
deps = [
"//library/net/http/blademaster:go_default_library",
"//library/net/http/blademaster/middleware/rate:go_default_library",
],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)

View File

@@ -0,0 +1,5 @@
### business/blademaster/rate
##### Version 1.0.0
1. 完成基本功能与测试

View File

@@ -0,0 +1,6 @@
# Author
lintnaghui
caoguoliang
# Reviewer
maojian

View File

@@ -0,0 +1,9 @@
# See the OWNERS docs at https://go.k8s.io/owners
approvers:
- caoguoliang
- lintnaghui
reviewers:
- caoguoliang
- lintnaghui
- maojian

View File

@@ -0,0 +1,13 @@
#### business/blademaster/rate
##### 项目简介
blademaster 的 rate middleware主要用于限制内部调用的频率
##### 编译环境
- **请只用 Golang v1.8.x 以上版本编译执行**
##### 依赖包
- No other dependency

View File

@@ -0,0 +1,28 @@
package rate_test
import (
"go-common/library/net/http/blademaster"
"go-common/library/net/http/blademaster/middleware/rate"
)
// This example create a rate middleware instance and attach to a blademaster engine,
// it will protect '/ping' API frequency with specified policy.
// If any internal service who requests this API more frequently than 1 req/second,
// a StatusTooManyRequests error will be raised.
func Example() {
lim := rate.New(&rate.Config{
URLs: map[string]*rate.Limit{
"/ping": &rate.Limit{Limit: 1, Burst: 2},
},
Apps: map[string]*rate.Limit{
"a-secret-app-key": &rate.Limit{Limit: 1, Burst: 2},
},
})
engine := blademaster.Default()
engine.Use(lim)
engine.GET("/ping", func(c *blademaster.Context) {
c.String(200, "%s", "pong")
})
engine.Run(":18080")
}

Some files were not shown because too many files have changed in this diff Show More