1
package api
2

3
import (
	"context"
	"encoding/base64"
	"net"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/gin-contrib/cors"
	"github.com/gin-contrib/pprof"
	"github.com/gin-gonic/gin"
	"github.com/miekg/dns"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"github.com/semihalev/log"
	"github.com/semihalev/sdns/config"
	"github.com/semihalev/sdns/dnsutil"
	"github.com/semihalev/sdns/middleware"
	"github.com/semihalev/sdns/middleware/blocklist"
)
21

22
// API type
23
type API struct {
24
	host      string
25
	blocklist *blocklist.BlockList
26
}
27

28
var debugpprof bool
29

30
func init() {
31 1
	gin.SetMode(gin.ReleaseMode)
32

33 1
	_, debugpprof = os.LookupEnv("SDNS_PPROF")
34
}
35

36
// New return new api
37
func New(cfg *config.Config) *API {
38 1
	var bl *blocklist.BlockList
39

40 1
	b := middleware.Get("blocklist")
41 1
	if b != nil {
42 1
		bl = b.(*blocklist.BlockList)
43
	}
44

45 1
	return &API{
46 1
		host:      cfg.API,
47 1
		blocklist: bl,
48
	}
49
}
50

51
func (a *API) existsBlock(c *gin.Context) {
52 1
	c.JSON(http.StatusOK, gin.H{"exists": a.blocklist.Exists(c.Param("key"))})
53
}
54

55
func (a *API) getBlock(c *gin.Context) {
56 1
	if ok, _ := a.blocklist.Get(dns.Fqdn(c.Param("key"))); !ok {
57 1
		c.JSON(http.StatusNotFound, gin.H{"error": c.Param("key") + " not found"})
58 1
	} else {
59 1
		c.JSON(http.StatusOK, gin.H{"success": ok})
60
	}
61
}
62

63
func (a *API) removeBlock(c *gin.Context) {
64 1
	a.blocklist.Remove(dns.Fqdn(c.Param("key")))
65 1
	c.JSON(http.StatusOK, gin.H{"success": true})
66
}
67

68
func (a *API) setBlock(c *gin.Context) {
69 1
	a.blocklist.Set(dns.Fqdn(c.Param("key")))
70 1
	c.JSON(http.StatusOK, gin.H{"success": true})
71
}
72

73
func (a *API) metrics(c *gin.Context) {
74 1
	promhttp.Handler().ServeHTTP(c.Writer, c.Request)
75
}
76

77
func (a *API) purge(c *gin.Context) {
78 1
	qtype := strings.ToUpper(c.Param("qtype"))
79 1
	qname := dns.Fqdn(c.Param("qname"))
80

81 1
	bqname := base64.StdEncoding.EncodeToString([]byte(qtype + ":" + qname))
82

83 1
	req := new(dns.Msg)
84 1
	req.SetQuestion(dns.Fqdn(bqname), dns.TypeNULL)
85 1
	req.Question[0].Qclass = dns.ClassCHAOS
86

87 1
	_, _ = dnsutil.ExchangeInternal(context.Background(), req)
88

89 1
	c.JSON(http.StatusOK, gin.H{"success": true})
90
}
91

92
// Run API server
93
func (a *API) Run() {
94 1
	if a.host == "" {
95 1
		return
96
	}
97

98 1
	r := gin.Default()
99 1
	r.Use(cors.Default())
100

101 1
	if debugpprof {
102 1
		pprof.Register(r)
103
	}
104

105 1
	if a.blocklist != nil {
106 1
		block := r.Group("/api/v1/block")
107
		{
108 1
			block.GET("/exists/:key", a.existsBlock)
109 1
			block.GET("/get/:key", a.getBlock)
110 1
			block.GET("/remove/:key", a.removeBlock)
111 1
			block.GET("/set/:key", a.setBlock)
112
		}
113
	}
114

115 1
	r.GET("/api/v1/purge/:qname/:qtype", a.purge)
116 1
	r.GET("/metrics", a.metrics)
117

118 1
	go func() {
119 1
		if err := r.Run(a.host); err != nil {
120 1
			log.Error("Start API server failed", "error", err.Error())
121
		}
122
	}()
123

124 1
	log.Info("API server listening...", "addr", a.host)
125
}

Read our documentation on viewing source code.

Loading