semihalev / sdns
1
package server
2

3
import (
4
	"bufio"
5
	"context"
6
	"io"
7
	l "log"
8
	"net/http"
9
	"strings"
10
	"sync"
11
	"time"
12

13
	"github.com/miekg/dns"
14

15
	"github.com/semihalev/log"
16
	"github.com/semihalev/sdns/config"
17
	"github.com/semihalev/sdns/middleware"
18
	"github.com/semihalev/sdns/mock"
19
	"github.com/semihalev/sdns/server/doh"
20
)
21

22
// Server runs the plain DNS, DNS-over-TLS and DNS-over-HTTPS listeners and
// answers every query through a pooled middleware chain.
type Server struct {
	// Listen addresses: plain DNS (udp and tcp), DNS-over-TLS, DNS-over-HTTPS.
	// An empty tlsAddr or dohAddr disables that listener.
	addr, tlsAddr, dohAddr string
	// Certificate and private key files used by the TLS and HTTPS listeners.
	tlsCertificate, tlsPrivateKey string

	// chainPool hands out reusable *middleware.Chain values; its New func is
	// installed by the package-level New constructor.
	chainPool sync.Pool
}
32

33
// New return new server
34
func New(cfg *config.Config) *Server {
35 1
	if cfg.Bind == "" {
36 1
		cfg.Bind = ":53"
37
	}
38

39 1
	server := &Server{
40 1
		addr:           cfg.Bind,
41 1
		tlsAddr:        cfg.BindTLS,
42 1
		dohAddr:        cfg.BindDOH,
43 1
		tlsCertificate: cfg.TLSCertificate,
44 1
		tlsPrivateKey:  cfg.TLSPrivateKey,
45
	}
46

47 1
	server.chainPool.New = func() interface{} {
48 1
		return middleware.NewChain(middleware.Handlers())
49
	}
50

51 1
	return server
52
}
53

54
// ServeDNS implements the Handle interface.
55
func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
56 1
	ch := s.chainPool.Get().(*middleware.Chain)
57

58 1
	ch.Reset(w, r)
59

60 1
	ch.Next(context.Background())
61

62 1
	s.chainPool.Put(ch)
63
}
64

65
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
66 1
	handle := func(req *dns.Msg) *dns.Msg {
67 1
		mw := mock.NewWriter("tcp", r.RemoteAddr)
68 1
		s.ServeDNS(mw, req)
69

70 1
		if !mw.Written() {
71 1
			return nil
72
		}
73

74 1
		return mw.Msg()
75
	}
76

77 1
	var handlerFn func(http.ResponseWriter, *http.Request)
78 1
	if r.Method == http.MethodGet && r.URL.Query().Get("dns") == "" {
79 1
		handlerFn = doh.HandleJSON(handle)
80 1
	} else {
81 1
		handlerFn = doh.HandleWireFormat(handle)
82
	}
83

84 1
	handlerFn(w, r)
85
}
86

87
// Run listen the services
88
func (s *Server) Run() {
89 1
	go s.ListenAndServeDNS("udp")
90 1
	go s.ListenAndServeDNS("tcp")
91 1
	go s.ListenAndServeDNSTLS()
92 1
	go s.ListenAndServeHTTPTLS()
93
}
94

95
// ListenAndServeDNS Starts a server on address and network specified Invoke handler
96
// for incoming queries.
97
func (s *Server) ListenAndServeDNS(network string) {
98 1
	log.Info("DNS server listening...", "net", network, "addr", s.addr)
99

100 1
	server := &dns.Server{
101 1
		Addr:          s.addr,
102 1
		Net:           network,
103 1
		Handler:       s,
104 1
		MaxTCPQueries: 2048,
105 1
		ReusePort:     true,
106
	}
107

108 1
	if err := server.ListenAndServe(); err != nil {
109 1
		log.Error("DNS listener failed", "net", network, "addr", s.addr, "error", err.Error())
110
	}
111
}
112

113
// ListenAndServeDNSTLS acts like http.ListenAndServeTLS
114
func (s *Server) ListenAndServeDNSTLS() {
115 1
	if s.tlsAddr == "" {
116 1
		return
117
	}
118

119 1
	log.Info("DNS server listening...", "net", "tcp-tls", "addr", s.tlsAddr)
120

121 1
	if err := dns.ListenAndServeTLS(s.tlsAddr, s.tlsCertificate, s.tlsPrivateKey, s); err != nil {
122 1
		log.Error("DNS listener failed", "net", "tcp-tls", "addr", s.tlsAddr, "error", err.Error())
123
	}
124
}
125

126
// ListenAndServeHTTPTLS acts like http.ListenAndServeTLS
127
func (s *Server) ListenAndServeHTTPTLS() {
128 1
	if s.dohAddr == "" {
129 1
		return
130
	}
131

132 1
	log.Info("DNS server listening...", "net", "https", "addr", s.dohAddr)
133

134 1
	logReader, logWriter := io.Pipe()
135 1
	go readlogs(logReader)
136

137 1
	srv := &http.Server{
138 1
		Addr:         s.dohAddr,
139 1
		Handler:      s,
140 1
		ReadTimeout:  30 * time.Second,
141 1
		WriteTimeout: 30 * time.Second,
142 1
		ErrorLog:     l.New(logWriter, "", 0),
143
	}
144

145 1
	if err := srv.ListenAndServeTLS(s.tlsCertificate, s.tlsPrivateKey); err != nil {
146 1
		log.Error("DNSs listener failed", "net", "https", "addr", s.dohAddr, "error", err.Error())
147
	}
148
}
149

150
func readlogs(rd io.Reader) {
151 1
	buf := bufio.NewReader(rd)
152
	for {
153 1
		line, err := buf.ReadBytes('\n')
154 1
		if err != nil {
155 0
			continue
156
		}
157

158 1
		parts := strings.SplitN(string(line[:len(line)-1]), " ", 2)
159 1
		if len(parts) > 1 {
160 1
			log.Warn("Client http socket failed", "net", "https", "error", parts[1])
161
		}
162
	}
163
}

Read our documentation on viewing source code.

Loading