all repos — quartzgun @ 483e59e2b26f5797f99336bbe325411cf0dcaf77

Lightweight web framework in Go.

router/router.go (raw)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
package router

import (
  "net/http"
  "html/template"
  "regexp"
  "log"
  "strconv"
  "strings"
  "path"
  "os"
  "errors"
  "context"
)

// Router dispatches HTTP requests either to files listed in StaticPaths or to
// the regexp-matched routes registered with Get, Post, Put, Delete and
// AddRoute. It implements http.Handler through its ServeHTTP method.
type Router struct {
  /* Fallback is the template rendered by ErrorPage for error responses
   * (ServeHTTP uses it for 403, 404 and 500). It is executed with the
   * *http.Request as its data; the error code and message are stored in the
   * request context under the "params" key.
   * NOTE(review): presumably the zero value (never parsed) fails to execute,
   * so callers should always set this — confirm.
   */
  Fallback template.Template
  /* routes is only filled through AddRoute (or its Get/Post/Put/Delete
   * shortcuts). ServeHTTP tries routes in registration order.
   */
  routes []Route
  /* StaticPaths can be filled from outside when constructing the Router.
   * key = URI prefix, matched with strings.HasPrefix against the request path
   * value = directory on disk; the rest of the request path is joined onto it
   * Only GET requests are served from these paths.
   */
  StaticPaths map[string]string
}

// Route pairs one URL path pattern with the handlers registered for it,
// keyed by HTTP method.
type Route struct {
  // path is the pattern built by AddRoute, anchored with "^" and "$" so it
  // must match the whole request path.
  path *regexp.Regexp
  // handlerMap maps an HTTP method name (e.g. "GET") to its handler.
  handlerMap map[string]http.Handler
}

// Get registers h for GET requests whose path fully matches the regular
// expression path. It is shorthand for AddRoute with the GET method.
func (self *Router) Get(path string, h http.Handler) {
  const method = http.MethodGet
  self.AddRoute(method, path, h)
}

// Post registers h for POST requests whose path fully matches the regular
// expression path. It is shorthand for AddRoute with the POST method.
func (self *Router) Post(path string, h http.Handler) {
  const method = http.MethodPost
  self.AddRoute(method, path, h)
}

// Put registers h for PUT requests whose path fully matches the regular
// expression path. It is shorthand for AddRoute with the PUT method.
func (self *Router) Put(path string, h http.Handler) {
  const method = http.MethodPut
  self.AddRoute(method, path, h)
}

// Delete registers h for DELETE requests whose path fully matches the
// regular expression path. It is shorthand for AddRoute with the DELETE
// method.
func (self *Router) Delete(path string, h http.Handler) {
  const method = http.MethodDelete
  self.AddRoute(method, path, h)
}

// AddRoute registers h to handle requests with the given HTTP method whose
// URL path matches the regular expression path in full. The pattern is
// anchored with "^" and "$". Named capture groups are exposed to the handler
// as request params (see ServeHTTP and Route.Match).
//
// If a route with the same pattern is already registered, the method is
// added to that route, replacing any handler previously registered for the
// same method. Otherwise a new route is appended. Routes are tried in
// registration order.
//
// AddRoute panics if path is not a valid regular expression.
func (self *Router) AddRoute(method string, path string, h http.Handler) {
  exactPath := regexp.MustCompile("^" + path + "$")

  /* If the route already exists, add this method to its handler map.
   * Compare the pattern source text: every MustCompile call returns a new
   * *Regexp, so comparing the pointers would never find an existing route.
   */
  for _, r := range self.routes {
    if r.path.String() == exactPath.String() {
      // handlerMap is a map, so writing through the loop copy updates the
      // stored route.
      r.handlerMap[method] = h
      return
    }
  }

  /* Otherwise add a new route */
  self.routes = append(self.routes, Route{
    path:       exactPath,
    handlerMap: map[string]http.Handler{method: h},
  })
}

// ServeHTTP implements http.Handler. Requests are dispatched in this order:
//
//  1. GET requests whose path starts with a key of StaticPaths are served
//     from the corresponding directory. When several keys match, the
//     longest one wins.
//  2. Otherwise, the first registered route is used whose pattern matches
//     the path and which has a handler for the request method. The form is
//     parsed and the route's regexp params are stored in the request
//     context under "params" before the handler runs.
//  3. If nothing matches, a 404 error page is rendered.
//
// A panic in any handler is recovered, logged and answered with a 500 error
// page. The exception is http.ErrAbortHandler, which is re-panicked so
// net/http can abort the response as it documents.
func (self *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
  /* Show the 500 error page if we panic */
  defer func() {
    if r := recover(); r != nil {
      if r == http.ErrAbortHandler {
        panic(r)
      }
      log.Println("ERROR:", r)
      self.ErrorPage(w, req, 500, "There was an error on the server.")
    }
  }()

  /* If the request matches any of our StaticPaths, try to serve a file.
   * Pick the longest matching prefix so overlapping entries resolve the same
   * way on every request (map iteration order is random).
   */
  if req.Method == "GET" {
    var uri, dir string
    found := false
    for u, d := range self.StaticPaths {
      if strings.HasPrefix(req.URL.Path, u) && (!found || len(u) > len(uri)) {
        uri, dir, found = u, d, true
      }
    }
    if found {
      self.serveStatic(w, req, uri, dir)
      return
    }
  }

  /* Otherwise, this is a normal route */
  for _, r := range self.routes {
    /* Pull the params out of the regex;
     * if the path doesn't match the regex, params will be nil.
     */
    params := r.Match(req)
    if params == nil {
      continue
    }
    handler, ok := r.handlerMap[req.Method]
    if !ok {
      continue
    }
    // Best-effort, as before: a malformed form is not treated as fatal here.
    _ = req.ParseForm()
    ProcessParams(req, params)
    handler.ServeHTTP(w, req)
    return
  }
  self.ErrorPage(w, req, 404, "The page you requested does not exist!")
}

// serveStatic serves the file that the part of req.URL.Path after uri names
// inside dir. It renders an error page instead when the target is missing
// (404), is a directory or is not readable (403), or cannot be examined for
// another reason (500).
func (self *Router) serveStatic(w http.ResponseWriter, req *http.Request, uri, dir string) {
  if dir == "" {
    dir = "."
  }
  // Rooting the remainder at "/" before cleaning resolves any ".." segments
  // against that root, so the joined path can never climb out of dir.
  rest := path.Clean("/" + strings.TrimPrefix(req.URL.Path, uri))
  p := path.Join(dir, rest)

  info, err := os.Stat(p)
  switch {
  case err == nil && !info.IsDir():
    http.ServeFile(w, req, p)
  case err == nil:
    // Directories are never listed.
    self.ErrorPage(w, req, 403, "Access forbidden")
  case errors.Is(err, os.ErrNotExist):
    self.ErrorPage(w, req, 404, "The requested file does not exist")
  case errors.Is(err, os.ErrPermission):
    self.ErrorPage(w, req, 403, "Access forbidden")
  default:
    // Some unexpected filesystem error. info is nil here, so it must not be
    // touched.
    self.ErrorPage(w, req, 500, "Internal server error")
  }
}

/*******************
 * Utility Methods *
 *******************/

// ProcessParams stores params in the request's context under the string key
// "params". The request is updated in place: *req is overwritten with the
// shallow copy returned by WithContext. Any pointer to req therefore sees
// the new context, and any earlier "params" value is shadowed.
func ProcessParams(req *http.Request, params map[string]string) {
  ctx := context.WithValue(req.Context(), "params", params)
  withParams := req.WithContext(ctx)
  *req = *withParams
}

// Match tests the request's URL path against this route's pattern.
//
// It returns nil when the path does not match. On a match, it returns a
// non-nil map from capture-group name to the captured text. The whole match
// (index 0) and every unnamed group all use the key "". That entry therefore
// ends up holding the last unnamed group's text, or the whole match when
// there are no unnamed groups.
func (self *Route) Match(r *http.Request) map[string]string {
  submatches := self.path.FindStringSubmatch(r.URL.Path)
  if submatches == nil {
    return nil
  }

  names := self.path.SubexpNames()
  params := make(map[string]string, len(submatches))
  for idx := range submatches {
    params[names[idx]] = submatches[idx]
  }
  return params
}

// ErrorPage responds with the given status code and renders the Fallback
// template as the body.
//
// Before rendering, ProcessParams replaces the request's "params" context
// value with {"ErrorCode": <code as decimal>, "ErrorMessage": errMsg}. The
// template is then executed with the *http.Request as its data.
//
// The status line is written before the template runs, so a rendering
// failure cannot change it. Such failures are logged rather than silently
// dropped.
func (self *Router) ErrorPage(w http.ResponseWriter, req *http.Request, code int, errMsg string) {
  // Fallback is an html/template, so label the body as HTML instead of
  // relying on content sniffing. The header must be set before WriteHeader.
  w.Header().Set("Content-Type", "text/html; charset=utf-8")
  w.WriteHeader(code)
  params := map[string]string{
    "ErrorCode":    strconv.Itoa(code),
    "ErrorMessage": errMsg,
  }
  ProcessParams(req, params)
  if err := self.Fallback.Execute(w, req); err != nil {
    log.Println("ERROR: rendering error page:", err)
  }
}