Implement embedding and chat model factories for multiple providers
This commit is contained in:
parent
203ba4a45c
commit
bc87753f2d
|
|
@ -1,7 +1,6 @@
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from concurrent import futures
|
from concurrent import futures
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import brunix_pb2
|
import brunix_pb2
|
||||||
import brunix_pb2_grpc
|
import brunix_pb2_grpc
|
||||||
|
|
@ -13,20 +12,10 @@ from utils.llm_factory import create_chat_model
|
||||||
from utils.emb_factory import create_embedding_model
|
from utils.emb_factory import create_embedding_model
|
||||||
from graph import build_graph
|
from graph import build_graph
|
||||||
|
|
||||||
# PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
|
||||||
# if str(PROJECT_ROOT) not in sys.path:
|
|
||||||
# sys.path.insert(0, str(PROJECT_ROOT))
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger("brunix-engine")
|
logger = logging.getLogger("brunix-engine")
|
||||||
|
|
||||||
|
|
||||||
def _provider_kwargs(provider: str, base_url: str) -> dict[str, Any]:
|
|
||||||
if provider == "ollama":
|
|
||||||
return {"base_url": base_url}
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
class BrunixEngine(brunix_pb2_grpc.AssistanceEngineServicer):
|
class BrunixEngine(brunix_pb2_grpc.AssistanceEngineServicer):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.llm = create_chat_model(
|
self.llm = create_chat_model(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue