feat: support radius accounting request (#2362)

* feat: add radius server

* feat: parse org from packet

* feat: add comment

* feat: support radius accounting

* feat: change log

* feat: add copyright
This commit is contained in:
haiwu 2023-09-26 22:48:00 +08:00 committed by GitHub
parent 981908b0b6
commit 6fe5c44c1c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 229 additions and 38 deletions

View File

@ -21,6 +21,7 @@ isDemoMode = false
batchSize = 100 batchSize = 100
ldapServerPort = 389 ldapServerPort = 389
radiusServerPort = 1812 radiusServerPort = 1812
radiusSecret = "secret"
quota = {"organization": -1, "user": -1, "application": -1, "provider": -1} quota = {"organization": -1, "user": -1, "application": -1, "provider": -1}
logConfig = {"filename": "logs/casdoor.log", "maxdays":99999, "perm":"0770"} logConfig = {"filename": "logs/casdoor.log", "maxdays":99999, "perm":"0770"}
initDataFile = "./init_data.json" initDataFile = "./init_data.json"

View File

@ -316,6 +316,11 @@ func (a *Ormer) createTable() {
panic(err) panic(err)
} }
err = a.Engine.Sync2(new(RadiusAccounting))
if err != nil {
panic(err)
}
err = a.Engine.Sync2(new(PermissionRule)) err = a.Engine.Sync2(new(PermissionRule))
if err != nil { if err != nil {
panic(err) panic(err)

124
object/radius.go Normal file
View File

@ -0,0 +1,124 @@
package object
import (
"fmt"
"time"
"github.com/casdoor/casdoor/util"
"github.com/xorm-io/core"
)
// https://www.cisco.com/c/en/us/td/docs/ios-xml/ios/sec_usr_radatt/configuration/xe-16/sec-usr-radatt-xe-16-book/sec-rad-ov-ietf-attr.html
// https://support.huawei.com/enterprise/zh/doc/EDOC1000178159/35071f9a
type RadiusAccounting struct {
Owner string `xorm:"varchar(100) notnull pk" json:"owner"`
Name string `xorm:"varchar(100) notnull pk" json:"name"`
CreatedTime time.Time `json:"createdTime"`
Username string `xorm:"index" json:"username"`
ServiceType int64 `json:"serviceType"` // e.g. LoginUser (1)
NasId string `json:"nasId"` // String identifying the network access server originating the Access-Request.
NasIpAddr string `json:"nasIpAddr"` // e.g. "192.168.0.10"
NasPortId string `json:"nasPortId"` // Contains a text string which identifies the port of the NAS that is authenticating the user. e.g."eth.0"
NasPortType int64 `json:"nasPortType"` // Indicates the type of physical port the network access server is using to authenticate the user. e.g.Ethernet15
NasPort int64 `json:"nasPort"` // Indicates the physical port number of the network access server that is authenticating the user. e.g. 233
FramedIpAddr string `json:"framedIpAddr"` // Indicates the IP address to be configured for the user by sending the IP address of a user to the RADIUS server.
FramedIpNetmask string `json:"framedIpNetmask"` // Indicates the IP netmask to be configured for the user when the user is using a device on a network.
AcctSessionId string `xorm:"index" json:"acctSessionId"`
AcctSessionTime int64 `json:"acctSessionTime"` // Indicates how long (in seconds) the user has received service.
AcctInputTotal int64 `json:"acctInputTotal"`
AcctOutputTotal int64 `json:"acctOutputTotal"`
AcctInputPackets int64 `json:"acctInputPackets"` // Indicates how many packets have been received from the port over the course of this service being provided to a framed user.
AcctOutputPackets int64 `json:"acctOutputPackets"` // Indicates how many packets have been sent to the port in the course of delivering this service to a framed user.
AcctTerminateCause int64 `json:"acctTerminateCause"` // e.g. Lost-Carrier (2)
LastUpdate time.Time `json:"lastUpdate"`
AcctStartTime time.Time `xorm:"index" json:"acctStartTime"`
AcctStopTime time.Time `xorm:"index" json:"acctStopTime"`
}
func (ra *RadiusAccounting) GetId() string {
return util.GetId(ra.Owner, ra.Name)
}
func getRadiusAccounting(owner, name string) (*RadiusAccounting, error) {
if owner == "" || name == "" {
return nil, nil
}
ra := RadiusAccounting{Owner: owner, Name: name}
existed, err := ormer.Engine.Get(&ra)
if err != nil {
return nil, err
}
if existed {
return &ra, nil
} else {
return nil, nil
}
}
func getPaginationRadiusAccounting(owner, field, value, sortField, sortOrder string, offset, limit int) ([]*RadiusAccounting, error) {
ras := []*RadiusAccounting{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&ras)
if err != nil {
return ras, err
}
return ras, nil
}
func GetRadiusAccounting(id string) (*RadiusAccounting, error) {
owner, name := util.GetOwnerAndNameFromId(id)
return getRadiusAccounting(owner, name)
}
func GetRadiusAccountingBySessionId(sessionId string) (*RadiusAccounting, error) {
ras, err := getPaginationRadiusAccounting("", "acct_session_id", sessionId, "created_time", "desc", 0, 1)
if err != nil {
return nil, err
}
if len(ras) == 0 {
return nil, nil
}
return ras[0], nil
}
func AddRadiusAccounting(ra *RadiusAccounting) error {
_, err := ormer.Engine.Insert(ra)
return err
}
func DeleteRadiusAccounting(ra *RadiusAccounting) error {
_, err := ormer.Engine.ID(core.PK{ra.Owner, ra.Name}).Delete(&RadiusAccounting{})
return err
}
func UpdateRadiusAccounting(id string, ra *RadiusAccounting) error {
owner, name := util.GetOwnerAndNameFromId(id)
_, err := ormer.Engine.ID(core.PK{owner, name}).Update(ra)
return err
}
func InterimUpdateRadiusAccounting(oldRa *RadiusAccounting, newRa *RadiusAccounting, stop bool) error {
if oldRa.AcctSessionId != newRa.AcctSessionId {
return fmt.Errorf("AcctSessionId is not equal, newRa = %s, oldRa = %s", newRa.AcctSessionId, oldRa.AcctSessionId)
}
oldRa.AcctInputTotal = newRa.AcctInputTotal
oldRa.AcctOutputTotal = newRa.AcctOutputTotal
oldRa.AcctInputPackets = newRa.AcctInputPackets
oldRa.AcctOutputPackets = newRa.AcctOutputPackets
oldRa.AcctSessionTime = newRa.AcctSessionTime
if stop {
oldRa.AcctStopTime = newRa.AcctStopTime
if oldRa.AcctStopTime.IsZero() {
oldRa.AcctStopTime = time.Now()
}
oldRa.AcctTerminateCause = newRa.AcctTerminateCause
} else {
oldRa.LastUpdate = time.Now()
}
return UpdateRadiusAccounting(oldRa.GetId(), oldRa)
}

View File

@ -15,24 +15,27 @@
package radius package radius
import ( import (
"fmt"
"log" "log"
"github.com/casdoor/casdoor/conf" "github.com/casdoor/casdoor/conf"
"github.com/casdoor/casdoor/object" "github.com/casdoor/casdoor/object"
"layeh.com/radius" "layeh.com/radius"
"layeh.com/radius/rfc2865" "layeh.com/radius/rfc2865"
"layeh.com/radius/rfc2866"
) )
// https://support.huawei.com/enterprise/zh/doc/EDOC1000178159/35071f9a#tab_3 // https://support.huawei.com/enterprise/zh/doc/EDOC1000178159/35071f9a#tab_3
func StartRadiusServer() { func StartRadiusServer() {
secret := conf.GetConfigString("radiusSecret")
server := radius.PacketServer{ server := radius.PacketServer{
Addr: "0.0.0.0:" + conf.GetConfigString("radiusServerPort"), Addr: "0.0.0.0:" + conf.GetConfigString("radiusServerPort"),
Handler: radius.HandlerFunc(handlerRadius), Handler: radius.HandlerFunc(handlerRadius),
SecretSource: radius.StaticSecretSource([]byte(`secret`)), SecretSource: radius.StaticSecretSource([]byte(secret)),
} }
log.Printf("Starting Radius server on %s", server.Addr) log.Printf("Starting Radius server on %s", server.Addr)
if err := server.ListenAndServe(); err != nil { if err := server.ListenAndServe(); err != nil {
log.Printf("StartRadiusServer() failed, err = %s", err.Error()) log.Printf("StartRadiusServer() failed, err = %v", err)
} }
} }
@ -40,6 +43,8 @@ func handlerRadius(w radius.ResponseWriter, r *radius.Request) {
switch r.Code { switch r.Code {
case radius.CodeAccessRequest: case radius.CodeAccessRequest:
handleAccessRequest(w, r) handleAccessRequest(w, r)
case radius.CodeAccountingRequest:
handleAccountingRequest(w, r)
default: default:
log.Printf("radius message, code = %d", r.Code) log.Printf("radius message, code = %d", r.Code)
} }
@ -48,20 +53,57 @@ func handlerRadius(w radius.ResponseWriter, r *radius.Request) {
func handleAccessRequest(w radius.ResponseWriter, r *radius.Request) { func handleAccessRequest(w radius.ResponseWriter, r *radius.Request) {
username := rfc2865.UserName_GetString(r.Packet) username := rfc2865.UserName_GetString(r.Packet)
password := rfc2865.UserPassword_GetString(r.Packet) password := rfc2865.UserPassword_GetString(r.Packet)
organization := parseOrganization(r.Packet) organization := rfc2865.Class_GetString(r.Packet)
code := radius.CodeAccessAccept log.Printf("handleAccessRequest() username=%v, org=%v, password=%v", username, organization, password)
log.Printf("username=%v, password=%v, code=%v, org=%v", username, password, code, organization)
if organization == "" { if organization == "" {
code = radius.CodeAccessReject w.Write(r.Response(radius.CodeAccessReject))
w.Write(r.Response(code))
return return
} }
_, msg := object.CheckUserPassword(organization, username, password, "en") _, msg := object.CheckUserPassword(organization, username, password, "en")
if msg != "" { if msg != "" {
code = radius.CodeAccessReject w.Write(r.Response(radius.CodeAccessReject))
w.Write(r.Response(code))
return return
} }
w.Write(r.Response(code)) w.Write(r.Response(radius.CodeAccessAccept))
}
func handleAccountingRequest(w radius.ResponseWriter, r *radius.Request) {
statusType := rfc2866.AcctStatusType_Get(r.Packet)
username := rfc2865.UserName_GetString(r.Packet)
organization := rfc2865.Class_GetString(r.Packet)
log.Printf("handleAccountingRequest() username=%v, org=%v, statusType=%v", username, organization, statusType)
w.Write(r.Response(radius.CodeAccountingResponse))
var err error
defer func() {
if err != nil {
log.Printf("handleAccountingRequest() failed, err = %v", err)
}
}()
switch statusType {
case rfc2866.AcctStatusType_Value_Start:
// Start an accounting session
ra := GetAccountingFromRequest(r)
err = object.AddRadiusAccounting(ra)
case rfc2866.AcctStatusType_Value_InterimUpdate, rfc2866.AcctStatusType_Value_Stop:
// Interim update to an accounting session | Stop an accounting session
var (
newRa = GetAccountingFromRequest(r)
oldRa *object.RadiusAccounting
)
oldRa, err = object.GetRadiusAccountingBySessionId(newRa.AcctSessionId)
if err != nil {
return
}
if oldRa == nil {
if err = object.AddRadiusAccounting(newRa); err != nil {
return
}
}
stop := statusType == rfc2866.AcctStatusType_Value_Stop
err = object.InterimUpdateRadiusAccounting(oldRa, newRa, stop)
case rfc2866.AcctStatusType_Value_AccountingOn, rfc2866.AcctStatusType_Value_AccountingOff:
// By default, no Accounting-On or Accounting-Off messages are sent (no acct-on-off).
default:
err = fmt.Errorf("unsupport statusType = %v", statusType)
}
} }

View File

@ -29,11 +29,7 @@ func TestAccessRequestRejected(t *testing.T) {
packet := radius.New(radius.CodeAccessRequest, []byte(`secret`)) packet := radius.New(radius.CodeAccessRequest, []byte(`secret`))
rfc2865.UserName_SetString(packet, "admin") rfc2865.UserName_SetString(packet, "admin")
rfc2865.UserPassword_SetString(packet, "12345") rfc2865.UserPassword_SetString(packet, "12345")
vsa, err := radius.NewVendorSpecific(OrganizationVendorID, []byte("built-in")) rfc2865.Class_SetString(packet, "built-in")
if err != nil {
t.Fatal(err)
}
packet.Add(rfc2865.VendorSpecific_Type, vsa)
response, err := radius.Exchange(context.Background(), packet, "localhost:1812") response, err := radius.Exchange(context.Background(), packet, "localhost:1812")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -47,11 +43,7 @@ func TestAccessRequestAccepted(t *testing.T) {
packet := radius.New(radius.CodeAccessRequest, []byte(`secret`)) packet := radius.New(radius.CodeAccessRequest, []byte(`secret`))
rfc2865.UserName_SetString(packet, "admin") rfc2865.UserName_SetString(packet, "admin")
rfc2865.UserPassword_SetString(packet, "123") rfc2865.UserPassword_SetString(packet, "123")
vsa, err := radius.NewVendorSpecific(OrganizationVendorID, []byte("built-in")) rfc2865.Class_SetString(packet, "built-in")
if err != nil {
t.Fatal(err)
}
packet.Add(rfc2865.VendorSpecific_Type, vsa)
response, err := radius.Exchange(context.Background(), packet, "localhost:1812") response, err := radius.Exchange(context.Background(), packet, "localhost:1812")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@ -15,26 +15,53 @@
package radius package radius
import ( import (
"fmt"
"time"
"github.com/casdoor/casdoor/object"
"github.com/casdoor/casdoor/util"
"layeh.com/radius" "layeh.com/radius"
"layeh.com/radius/rfc2865" "layeh.com/radius/rfc2865"
"layeh.com/radius/rfc2866"
"layeh.com/radius/rfc2869"
) )
const ( func GetAccountingFromRequest(r *radius.Request) *object.RadiusAccounting {
OrganizationVendorID = uint32(100) acctInputOctets := int(rfc2866.AcctInputOctets_Get(r.Packet))
) acctInputGigawords := int(rfc2869.AcctInputGigawords_Get(r.Packet))
acctOutputOctets := int(rfc2866.AcctOutputOctets_Get(r.Packet))
acctOutputGigawords := int(rfc2869.AcctOutputGigawords_Get(r.Packet))
organization := rfc2865.Class_GetString(r.Packet)
getAcctStartTime := func(sessionTime int) time.Time {
m, _ := time.ParseDuration(fmt.Sprintf("-%ds", sessionTime))
return time.Now().Add(m)
}
ra := &object.RadiusAccounting{
Owner: organization,
Name: "ra_" + util.GenerateId()[:6],
CreatedTime: time.Now(),
func parseOrganization(p *radius.Packet) string { Username: rfc2865.UserName_GetString(r.Packet),
for _, avp := range p.Attributes { ServiceType: int64(rfc2865.ServiceType_Get(r.Packet)),
if avp.Type == rfc2865.VendorSpecific_Type {
attr := avp.Attribute NasId: rfc2865.NASIdentifier_GetString(r.Packet),
vendorId, value, err := radius.VendorSpecific(attr) NasIpAddr: rfc2865.NASIPAddress_Get(r.Packet).String(),
if err != nil { NasPortId: rfc2869.NASPortID_GetString(r.Packet),
return "" NasPortType: int64(rfc2865.NASPortType_Get(r.Packet)),
NasPort: int64(rfc2865.NASPort_Get(r.Packet)),
FramedIpAddr: rfc2865.FramedIPAddress_Get(r.Packet).String(),
FramedIpNetmask: rfc2865.FramedIPNetmask_Get(r.Packet).String(),
AcctSessionId: rfc2866.AcctSessionID_GetString(r.Packet),
AcctSessionTime: int64(rfc2866.AcctSessionTime_Get(r.Packet)),
AcctInputTotal: int64(acctInputOctets) + int64(acctInputGigawords)*4*1024*1024*1024,
AcctOutputTotal: int64(acctOutputOctets) + int64(acctOutputGigawords)*4*1024*1024*1024,
AcctInputPackets: int64(rfc2866.AcctInputPackets_Get(r.Packet)),
AcctOutputPackets: int64(rfc2866.AcctInputPackets_Get(r.Packet)),
AcctStartTime: getAcctStartTime(int(rfc2866.AcctSessionTime_Get(r.Packet))),
AcctTerminateCause: int64(rfc2866.AcctTerminateCause_Get(r.Packet)),
LastUpdate: time.Now(),
} }
if vendorId == OrganizationVendorID { return ra
return string(value)
}
}
}
return ""
} }